Add tenant context (#2596)

* add proper tenant context to background tasks

* update for new session logic

* remove unnecessary functions

* add additional tenant context

* update ports

* proper format / directory structure

* update ports

* ensure tenant context is properly passed to EE background tasks

* add user provisioning

* nit

* validated for multi-tenant

* auth

* nit

* nit

* nit

* nit

* validate pruning

* evaluate integration tests

* at long last, validated celery beat

* nit: minor edge case patched

* minor

* validate update

* nit
pablodanswer 2024-10-10 09:34:32 -07:00 committed by GitHub
parent 9be54a2b4c
commit f40c5ca9bd
52 changed files with 1319 additions and 389 deletions

View File

@ -1,6 +1,6 @@
# A generic, single database configuration.
[alembic]
[DEFAULT]
# path to migration scripts
script_location = alembic
@ -47,7 +47,8 @@ prepend_sys_path = .
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
version_path_separator = os
# Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
@ -106,3 +107,12 @@ formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
[alembic]
script_location = alembic
version_locations = %(script_location)s/versions
[schema_private]
script_location = alembic_tenants
version_locations = %(script_location)s/versions
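The ini file now carries two Alembic configurations: the default [alembic] section keeps driving the migrations under alembic/, while the new [schema_private] section points at alembic_tenants/ for the shared public tables. A minimal sketch of invoking the second configuration, assuming the standard Alembic command entry point (only the section and directory names come from the ini above):

from alembic.config import main

# Run the public-table migrations defined under alembic_tenants/
# by selecting the [schema_private] section of alembic.ini.
main(argv=["--name", "schema_private", "upgrade", "head"])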

View File

@ -1,21 +1,22 @@
from typing import Any
import asyncio
from logging.config import fileConfig
from alembic import context
from danswer.db.engine import build_connection_string
from danswer.db.models import Base
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from celery.backends.database.session import ResultModelBase # type: ignore
from sqlalchemy.schema import SchemaItem
from sqlalchemy.sql import text
from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.engine import build_connection_string
from danswer.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore
# Alembic Config object
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None and config.attributes.get(
"configure_logger", True
):
@ -35,8 +36,7 @@ def get_schema_options() -> tuple[str, bool]:
for pair in arg.split(","):
if "=" in pair:
key, value = pair.split("=", 1)
x_args[key] = value
x_args[key.strip()] = value.strip()
schema_name = x_args.get("schema", "public")
create_schema = x_args.get("create_schema", "true").lower() == "true"
return schema_name, create_schema
@ -46,11 +46,7 @@ EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
def include_object(
object: SchemaItem,
name: str,
type_: str,
reflected: bool,
compare_to: SchemaItem | None,
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
) -> bool:
if type_ == "table" and name in EXCLUDE_TABLES:
return False
@ -59,7 +55,6 @@ def include_object(
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
@ -67,17 +62,18 @@ def run_migrations_offline() -> None:
Calls to context.execute() here emit the given string to the
script output.
"""
schema_name, _ = get_schema_options()
url = build_connection_string()
schema, _ = get_schema_options()
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
include_object=include_object,
dialect_opts={"paramstyle": "named"},
version_table_schema=schema,
version_table_schema=schema_name,
include_schemas=True,
script_location=config.get_main_option("script_location"),
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
@ -85,20 +81,30 @@ def run_migrations_offline() -> None:
def do_run_migrations(connection: Connection) -> None:
schema, create_schema = get_schema_options()
schema_name, create_schema = get_schema_options()
if MULTI_TENANT and schema_name == "public":
raise ValueError(
"Cannot run default migrations in public schema when multi-tenancy is enabled. "
"Please specify a tenant-specific schema."
)
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
connection.execute(text("COMMIT"))
connection.execute(text(f'SET search_path TO "{schema}"'))
# Set search_path to the target schema
connection.execute(text(f'SET search_path TO "{schema_name}"'))
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore
version_table_schema=schema,
include_object=include_object,
version_table_schema=schema_name,
include_schemas=True,
compare_type=True,
compare_server_default=True,
script_location=config.get_main_option("script_location"),
)
with context.begin_transaction():
@ -106,7 +112,6 @@ def do_run_migrations(connection: Connection) -> None:
async def run_async_migrations() -> None:
"""Run migrations in 'online' mode."""
connectable = create_async_engine(
build_connection_string(),
poolclass=pool.NullPool,
@ -119,7 +124,6 @@ async def run_async_migrations() -> None:
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
asyncio.run(run_async_migrations())
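With the tightened x-argument parsing and the public-schema guard above, the same env.py can be pointed at any tenant schema from the command line via -x. A minimal sketch of upgrading several tenant schemas in a loop; the schema names are hypothetical, and in the application they would come from get_all_tenant_ids():

from alembic.config import main

# Hypothetical tenant schemas; real ones are discovered from information_schema.
for schema in ["tenant_a1b2", "tenant_c3d4"]:
    main(argv=["-x", f"schema={schema},create_schema=true", "upgrade", "head"])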

View File

@ -20,7 +20,7 @@ depends_on: None = None
def upgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text("select id, chosen_assistants from public.user")
sa.text('select id, chosen_assistants from "user"')
)
op.drop_column(
"user",
@ -37,7 +37,7 @@ def upgrade() -> None:
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
'update "user" set chosen_assistants = :chosen_assistants where id = :id'
),
{"chosen_assistants": json.dumps(chosen_assistants), "id": id},
)
@ -46,7 +46,7 @@ def upgrade() -> None:
def downgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text("select id, chosen_assistants from public.user")
sa.text('select id, chosen_assistants from "user"')
)
op.drop_column(
"user",
@ -59,7 +59,7 @@ def downgrade() -> None:
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
'update "user" set chosen_assistants = :chosen_assistants where id = :id'
),
{"chosen_assistants": chosen_assistants, "id": id},
)
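The change from public.user to a quoted, schema-unqualified "user" makes the statement resolve the table through the active search_path (the tenant schema under multi-tenancy), while the quoting keeps it from colliding with the reserved USER keyword. A minimal sketch of the same resolution behavior, with a hypothetical connection URL and tenant schema:

from sqlalchemy import create_engine, text

# Illustrative URL and schema only; the app builds these from its own config.
engine = create_engine("postgresql+psycopg2://danswer:password@localhost/danswer")
with engine.connect() as conn:
    conn.execute(text('SET search_path TO "tenant_a1b2"'))
    # The unqualified, quoted "user" now resolves to tenant_a1b2.user
    # rather than public.user.
    rows = conn.execute(text('SELECT id, chosen_assistants FROM "user"')).all()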

View File

@ -0,0 +1,3 @@
These files are for public table migrations when operating with multi-tenancy.
If you are not a Danswer developer, you can ignore this directory entirely.

View File

@ -0,0 +1,111 @@
import asyncio
from logging.config import fileConfig
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.schema import SchemaItem
from alembic import context
from danswer.db.engine import build_connection_string
from danswer.db.models import PublicBase
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None and config.attributes.get(
"configure_logger", True
):
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = [PublicBase.metadata]
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
def include_object(
object: SchemaItem,
name: str,
type_: str,
reflected: bool,
compare_to: SchemaItem | None,
) -> bool:
if type_ == "table" and name in EXCLUDE_TABLES:
return False
return True
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = build_connection_string()
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore
include_object=include_object,
) # type: ignore
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
"""In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = create_async_engine(
build_connection_string(),
poolclass=pool.NullPool,
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
asyncio.run(run_async_migrations())
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@ -0,0 +1,24 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@ -0,0 +1,24 @@
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "14a83a331951"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"user_tenant_mapping",
sa.Column("email", sa.String(), nullable=False),
sa.Column("tenant_id", sa.String(), nullable=False),
sa.UniqueConstraint("email", "tenant_id", name="uq_user_tenant"),
sa.UniqueConstraint("email", name="uq_email"),
schema="public",
)
def downgrade() -> None:
op.drop_table("user_tenant_mapping", schema="public")
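The commit's user-provisioning work hangs off this mapping table: one row per email, each pointing at a tenant. A minimal sketch of writing such a row during provisioning; the helper name and session handling are illustrative, and only the table and column names come from the migration above:

from sqlalchemy import text
from sqlalchemy.orm import Session

def map_user_to_tenant(db_session: Session, email: str, tenant_id: str) -> None:
    # Hypothetical helper: uq_email guarantees at most one tenant per email.
    db_session.execute(
        text(
            "INSERT INTO public.user_tenant_mapping (email, tenant_id) "
            "VALUES (:email, :tenant_id)"
        ),
        {"email": email, "tenant_id": tenant_id},
    )
    db_session.commit()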

View File

@ -34,6 +34,7 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
has_web_login: bool | None = True
tenant_id: str | None = None
class UserUpdate(schemas.BaseUserUpdate):
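UserCreate now carries an optional tenant_id, so provisioning code can say which tenant a new account belongs to. A minimal, hypothetical usage sketch (all field values are made up):

from danswer.auth.schemas import UserCreate, UserRole

new_user = UserCreate(
    email="admin@example.com",
    password="change-me",
    role=UserRole.BASIC,
    tenant_id="tenant_a1b2",  # hypothetical tenant schema
)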

View File

@ -26,11 +26,14 @@ from fastapi_users import schemas
from fastapi_users import UUIDIDMixin
from fastapi_users.authentication import AuthenticationBackend
from fastapi_users.authentication import CookieTransport
from fastapi_users.authentication import JWTStrategy
from fastapi_users.authentication import Strategy
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.openapi import OpenAPIResponseType
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
from sqlalchemy import select
from sqlalchemy.orm import attributes
from sqlalchemy.orm import Session
from danswer.auth.invited_users import get_invited_users
@ -42,7 +45,9 @@ from danswer.configs.app_configs import DATA_PLANE_SECRET
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import EXPECTED_API_KEY
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import SMTP_PASS
from danswer.configs.app_configs import SMTP_PORT
@ -60,15 +65,21 @@ from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.db.models import UserTenantMapping
from danswer.db.users import get_user_by_email
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import current_tenant_id
logger = setup_logger()
@ -136,8 +147,8 @@ def verify_email_is_invited(email: str) -> None:
raise PermissionError("User not on allowed user whitelist")
def verify_email_in_whitelist(email: str) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
if not get_user_by_email(email, db_session):
verify_email_is_invited(email)
@ -157,6 +168,20 @@ def verify_email_domain(email: str) -> None:
)
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return "public"
# Look up the tenant_id for this email in the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
def send_user_verification_email(
user_email: str,
token: str,
@ -221,6 +246,29 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
raise exceptions.UserAlreadyExists()
return user
async def on_after_login(
self,
user: User,
request: Request | None = None,
response: Response | None = None,
) -> None:
if response is None or not MULTI_TENANT:
return
tenant_id = get_tenant_id_for_email(user.email)
tenant_token = jwt.encode(
{"tenant_id": tenant_id}, SECRET_JWT_KEY, algorithm="HS256"
)
response.set_cookie(
key="tenant_details",
value=tenant_token,
httponly=True,
secure=WEB_DOMAIN.startswith("https"),
samesite="lax",
)
async def oauth_callback(
self: "BaseUserManager[models.UOAP, models.ID]",
oauth_name: str,
@ -234,45 +282,111 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
verify_email_in_whitelist(account_email)
verify_email_domain(account_email)
user = await super().oauth_callback( # type: ignore
oauth_name=oauth_name,
access_token=access_token,
account_id=account_id,
account_email=account_email,
expires_at=expires_at,
refresh_token=refresh_token,
request=request,
associate_by_email=associate_by_email,
is_verified_by_default=is_verified_by_default,
)
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
# re-authenticate that frequently, so by default this is disabled
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry})
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY:
await self.user_db.update(user, update_dict={"oidc_expiry": None})
# Handle case where user has used product outside of web and is now creating an account through web
if not user.has_web_login:
await self.user_db.update(
user,
update_dict={
"is_verified": is_verified_by_default,
"has_web_login": True,
},
# Get tenant_id from mapping table
try:
tenant_id = (
get_tenant_id_for_email(account_email) if MULTI_TENANT else "public"
)
user.is_verified = is_verified_by_default
user.has_web_login = True
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
return user
if not tenant_id:
raise HTTPException(status_code=401, detail="User not found")
token = None
async with get_async_session_with_tenant(tenant_id) as db_session:
token = current_tenant_id.set(tenant_id)
# Verify the email is whitelisted and its domain is allowed for this tenant
verify_email_in_whitelist(account_email, tenant_id)
verify_email_domain(account_email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
self.user_db = tenant_user_db
self.database = tenant_user_db
oauth_account_dict = {
"oauth_name": oauth_name,
"access_token": access_token,
"account_id": account_id,
"account_email": account_email,
"expires_at": expires_at,
"refresh_token": refresh_token,
}
try:
# Attempt to get user by OAuth account
user = await self.get_by_oauth_account(oauth_name, account_id)
except exceptions.UserNotExists:
try:
# Attempt to get user by email
user = await self.get_by_email(account_email)
if not associate_by_email:
raise exceptions.UserAlreadyExists()
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
# If user not found by OAuth account or email, create a new user
except exceptions.UserNotExists:
password = self.password_helper.generate()
user_dict = {
"email": account_email,
"hashed_password": self.password_helper.hash(password),
"is_verified": is_verified_by_default,
}
user = await self.user_db.create(user_dict)
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
await self.on_after_register(user, request)
else:
for existing_oauth_account in user.oauth_accounts:
if (
existing_oauth_account.account_id == account_id
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user, existing_oauth_account, oauth_account_dict
)
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
# re-authenticate that frequently, so by default this is disabled
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
await self.user_db.update(
user, update_dict={"oidc_expiry": oidc_expiry}
)
# Handle case where user has used product outside of web and is now creating an account through web
if not user.has_web_login: # type: ignore
await self.user_db.update(
user,
{
"is_verified": is_verified_by_default,
"has_web_login": True,
},
)
user.is_verified = is_verified_by_default
user.has_web_login = True # type: ignore
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
if (
user.oidc_expiry is not None # type: ignore
and not TRACK_EXTERNAL_IDP_EXPIRY
):
await self.user_db.update(user, {"oidc_expiry": None})
user.oidc_expiry = None # type: ignore
if token:
current_tenant_id.reset(token)
return user
async def on_after_register(
self, user: User, request: Optional[Request] = None
@ -303,28 +417,51 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async def authenticate(
self, credentials: OAuth2PasswordRequestForm
) -> Optional[User]:
try:
user = await self.get_by_email(credentials.username)
except exceptions.UserNotExists:
email = credentials.username
# Get tenant_id from mapping table
tenant_id = get_tenant_id_for_email(email)
if not tenant_id:
# User not found in mapping
self.password_helper.hash(credentials.password)
return None
if not user.has_web_login:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
# Create a tenant-specific session
async with get_async_session_with_tenant(tenant_id) as tenant_session:
tenant_user_db: SQLAlchemyUserDatabase = SQLAlchemyUserDatabase(
tenant_session, User
)
self.user_db = tenant_user_db
verified, updated_password_hash = self.password_helper.verify_and_update(
credentials.password, user.hashed_password
)
if not verified:
return None
# Proceed with authentication
try:
user = await self.get_by_email(email)
if updated_password_hash is not None:
await self.user_db.update(user, {"hashed_password": updated_password_hash})
except exceptions.UserNotExists:
self.password_helper.hash(credentials.password)
return None
return user
has_web_login = attributes.get_attribute(user, "has_web_login")
if not has_web_login:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
verified, updated_password_hash = self.password_helper.verify_and_update(
credentials.password, user.hashed_password
)
if not verified:
return None
if updated_password_hash is not None:
await self.user_db.update(
user, {"hashed_password": updated_password_hash}
)
return user
async def get_user_manager(
@ -339,20 +476,26 @@ cookie_transport = CookieTransport(
)
def get_jwt_strategy() -> JWTStrategy:
return JWTStrategy(
secret=USER_AUTH_SECRET,
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
)
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
strategy = DatabaseStrategy(
return DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
)
return strategy
auth_backend = AuthenticationBackend(
name="database",
name="jwt" if MULTI_TENANT else "database",
transport=cookie_transport,
get_strategy=get_database_strategy,
)
get_strategy=get_jwt_strategy if MULTI_TENANT else get_database_strategy, # type: ignore
) # type: ignore
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
@ -366,9 +509,11 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
This way the login router does not need to be included
"""
router = APIRouter()
get_current_user_token = self.authenticator.current_user_token(
active=True, verified=requires_verification
)
logout_responses: OpenAPIResponseType = {
**{
status.HTTP_401_UNAUTHORIZED: {
@ -415,8 +560,8 @@ async def optional_user_(
async def optional_user(
request: Request,
user: User | None = Depends(optional_fastapi_current_user),
db_session: Session = Depends(get_session),
user: User | None = Depends(optional_fastapi_current_user),
) -> User | None:
versioned_fetch_user = fetch_versioned_implementation(
"danswer.auth.users", "optional_user_"

View File

@ -23,6 +23,7 @@ from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.update import get_all_tenant_ids
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerRedisLocks
@ -70,7 +71,6 @@ def celery_task_postrun(
return
task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")
# logger.debug(f"Result: {retval}")
if state not in READY_STATES:
return
@ -437,48 +437,58 @@ celery_app.autodiscover_tasks(
#####
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
"check-for-vespa-sync": {
tenant_ids = get_all_tenant_ids()
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": "check_for_vespa_sync_task",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}
celery_app.conf.beat_schedule.update(
{
"check-for-connector-deletion-task": {
"task": "check_for_connector_deletion_task",
# don't need to check too often, since we kick off a deletion initially
# during the API call that actually marks the CC pair for deletion
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}
)
celery_app.conf.beat_schedule.update(
"name": "check-for-connector-deletion",
"task": "check_for_connector_deletion_task",
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"check-for-prune": {
"task": "check_for_prune_task_2",
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}
)
celery_app.conf.beat_schedule.update(
"name": "check-for-prune",
"task": "check_for_prune_task_2",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"kombu-message-cleanup": {
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
}
)
celery_app.conf.beat_schedule.update(
"name": "kombu-message-cleanup",
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
{
"monitor-vespa-sync": {
"task": "monitor_vespa_sync",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}
)
"name": "monitor-vespa-sync",
"task": "monitor_vespa_sync",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
]
# Build the celery beat schedule dynamically
beat_schedule = {}
for tenant_id in tenant_ids:
for task in tasks_to_schedule:
task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task
beat_schedule[task_name] = {
"task": task["task"],
"schedule": task["schedule"],
"options": task["options"],
"args": (tenant_id,), # Must pass tenant_id as an argument
}
# Include any existing beat schedules
existing_beat_schedule = celery_app.conf.beat_schedule or {}
beat_schedule.update(existing_beat_schedule)
# Update the Celery app configuration once
celery_app.conf.beat_schedule = beat_schedule
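For illustration, with two hypothetical tenant schemas the loop above expands each of the five tasks into one beat entry per tenant, so the resulting schedule looks roughly like this (tenant ids are made up):

from datetime import timedelta

from danswer.configs.constants import DanswerCeleryPriority

# Roughly what beat_schedule contains for two hypothetical tenants;
# each of the other four tasks gets the same per-tenant expansion.
example_beat_schedule = {
    "check-for-vespa-sync-tenant_a1b2": {
        "task": "check_for_vespa_sync_task",
        "schedule": timedelta(seconds=5),
        "options": {"priority": DanswerCeleryPriority.HIGH},
        "args": ("tenant_a1b2",),
    },
    "check-for-vespa-sync-tenant_c3d4": {
        "task": "check_for_vespa_sync_task",
        "schedule": timedelta(seconds=5),
        "options": {"priority": DanswerCeleryPriority.HIGH},
        "args": ("tenant_c3d4",),
    },
}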

View File

@ -107,6 +107,7 @@ class RedisObjectHelper(ABC):
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
pass
@ -122,6 +123,7 @@ class RedisDocumentSet(RedisObjectHelper):
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
@ -146,7 +148,7 @@ class RedisDocumentSet(RedisObjectHelper):
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id),
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
@ -168,6 +170,7 @@ class RedisUserGroup(RedisObjectHelper):
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
@ -204,7 +207,7 @@ class RedisUserGroup(RedisObjectHelper):
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id),
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
@ -244,6 +247,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
@ -278,7 +282,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id),
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
@ -300,6 +304,7 @@ class RedisConnectorDeletion(RedisObjectHelper):
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
@ -336,6 +341,7 @@ class RedisConnectorDeletion(RedisObjectHelper):
document_id=doc.id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
@ -409,6 +415,7 @@ class RedisConnectorPruning(RedisObjectHelper):
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
last_lock_time = time.monotonic()
@ -442,6 +449,7 @@ class RedisConnectorPruning(RedisObjectHelper):
document_id=doc_id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,

View File

@ -23,7 +23,7 @@ from danswer.redis.redis_pool import get_redis_client
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def check_for_connector_deletion_task() -> None:
def check_for_connector_deletion_task(tenant_id: str | None) -> None:
r = get_redis_client()
lock_beat = r.lock(
@ -40,7 +40,7 @@ def check_for_connector_deletion_task() -> None:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
try_generate_document_cc_pair_cleanup_tasks(
cc_pair, db_session, r, lock_beat
cc_pair, db_session, r, lock_beat, tenant_id
)
except SoftTimeLimitExceeded:
task_logger.info(
@ -58,6 +58,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Note that syncing can still be required even if the number of sync tasks generated is zero.
@ -90,7 +91,9 @@ def try_generate_document_cc_pair_cleanup_tasks(
task_logger.info(
f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
)
tasks_generated = rcd.generate_tasks(celery_app, db_session, r, lock_beat)
tasks_generated = rcd.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
return None

View File

@ -24,17 +24,21 @@ from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
logger = setup_logger()
@shared_task(
name="check_for_prune_task_2",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_prune_task_2() -> None:
def check_for_prune_task_2(tenant_id: str | None) -> None:
r = get_redis_client()
lock_beat = r.lock(
@ -47,11 +51,11 @@ def check_for_prune_task_2() -> None:
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
tasks_created = ccpair_pruning_generator_task_creation_helper(
cc_pair, db_session, r, lock_beat
cc_pair, db_session, tenant_id, r, lock_beat
)
if not tasks_created:
continue
@ -71,6 +75,7 @@ def check_for_prune_task_2() -> None:
def ccpair_pruning_generator_task_creation_helper(
cc_pair: ConnectorCredentialPair,
db_session: Session,
tenant_id: str | None,
r: Redis,
lock_beat: redis.lock.Lock,
) -> int | None:
@ -101,13 +106,14 @@ def ccpair_pruning_generator_task_creation_helper(
if datetime.now(timezone.utc) < next_prune:
return None
return try_creating_prune_generator_task(cc_pair, db_session, r)
return try_creating_prune_generator_task(cc_pair, db_session, r, tenant_id)
def try_creating_prune_generator_task(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
tenant_id: str | None,
) -> int | None:
"""Checks for any conditions that should block the pruning generator task from being
created, then creates the task.
@ -140,7 +146,9 @@ def try_creating_prune_generator_task(
celery_app.send_task(
"connector_pruning_generator_task",
kwargs=dict(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_PRUNING,
task_id=custom_task_id,
@ -153,14 +161,16 @@ def try_creating_prune_generator_task(
@shared_task(name="connector_pruning_generator_task", soft_time_limit=JOB_TIMEOUT)
def connector_pruning_generator_task(connector_id: int, credential_id: int) -> None:
def connector_pruning_generator_task(
connector_id: int, credential_id: int, tenant_id: str | None
) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
r = get_redis_client()
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
try:
cc_pair = get_connector_credential_pair(
db_session=db_session,
@ -218,7 +228,9 @@ def connector_pruning_generator_task(connector_id: int, credential_id: int) -> None:
task_logger.info(
f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}"
)
tasks_generated = rcp.generate_tasks(celery_app, db_session, r, None)
tasks_generated = rcp.generate_tasks(
celery_app, db_session, r, None, tenant_id
)
if tasks_generated is None:
return None

View File

@ -1,7 +1,6 @@
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_app import task_logger
@ -11,7 +10,7 @@ from danswer.db.document import get_document
from danswer.db.document import get_document_connector_count
from danswer.db.document import mark_document_as_synced
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
@ -26,7 +25,11 @@ from danswer.server.documents.models import ConnectorCredentialPairIdentifier
max_retries=3,
)
def document_by_cc_pair_cleanup_task(
self: Task, document_id: str, connector_id: int, credential_id: int
self: Task,
document_id: str,
connector_id: int,
credential_id: int,
tenant_id: str | None,
) -> bool:
"""A lightweight subtask used to clean up document to cc pair relationships.
Created by connection deletion and connector pruning parent tasks."""
@ -44,7 +47,7 @@ def document_by_cc_pair_cleanup_task(
(6) delete all relevant entries from postgres
"""
try:
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
action = "skip"
chunks_affected = 0

View File

@ -38,6 +38,7 @@ from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.models import DocumentSet
@ -61,7 +62,7 @@ from danswer.utils.variable_functionality import noop_fallback
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def check_for_vespa_sync_task() -> None:
def check_for_vespa_sync_task(tenant_id: str | None) -> None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
@ -77,8 +78,8 @@ def check_for_vespa_sync_task() -> None:
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
try_generate_stale_document_sync_tasks(db_session, r, lock_beat)
with get_session_with_tenant(tenant_id) as db_session:
try_generate_stale_document_sync_tasks(db_session, r, lock_beat, tenant_id)
# check if any document sets are not synced
document_set_info = fetch_document_sets(
@ -86,7 +87,7 @@ def check_for_vespa_sync_task() -> None:
)
for document_set, _ in document_set_info:
try_generate_document_set_sync_tasks(
document_set, db_session, r, lock_beat
document_set, db_session, r, lock_beat, tenant_id
)
# check if any user groups are not synced
@ -101,7 +102,7 @@ def check_for_vespa_sync_task() -> None:
)
for usergroup in user_groups:
try_generate_user_group_sync_tasks(
usergroup, db_session, r, lock_beat
usergroup, db_session, r, lock_beat, tenant_id
)
except ModuleNotFoundError:
# Always exceptions on the MIT version, which is expected
@ -120,7 +121,7 @@ def check_for_vespa_sync_task() -> None:
def try_generate_stale_document_sync_tasks(
db_session: Session, r: Redis, lock_beat: redis.lock.Lock
db_session: Session, r: Redis, lock_beat: redis.lock.Lock, tenant_id: str | None
) -> int | None:
# the fence is up, do nothing
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
@ -145,7 +146,9 @@ def try_generate_stale_document_sync_tasks(
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
rc = RedisConnectorCredentialPair(cc_pair.id)
tasks_generated = rc.generate_tasks(celery_app, db_session, r, lock_beat)
tasks_generated = rc.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
continue
@ -169,7 +172,11 @@ def try_generate_stale_document_sync_tasks(
def try_generate_document_set_sync_tasks(
document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
document_set: DocumentSet,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@ -193,7 +200,9 @@ def try_generate_document_set_sync_tasks(
)
# Add all documents that need to be updated into the queue
tasks_generated = rds.generate_tasks(celery_app, db_session, r, lock_beat)
tasks_generated = rds.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
return None
@ -214,7 +223,11 @@ def try_generate_document_set_sync_tasks(
def try_generate_user_group_sync_tasks(
usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
usergroup: UserGroup,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@ -236,7 +249,9 @@ def try_generate_user_group_sync_tasks(
task_logger.info(
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
)
tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat)
tasks_generated = rug.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
return None
@ -471,7 +486,7 @@ def monitor_ccpair_pruning_taskset(
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task) -> None:
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> None:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
If the count is 0, that means all tasks finished and we should clean up.
@ -516,7 +531,7 @@ def monitor_vespa_sync(self: Task) -> None:
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
monitor_connector_deletion_taskset(key_bytes, r)
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
monitor_document_set_taskset(key_bytes, r, db_session)
@ -556,11 +571,13 @@ def monitor_vespa_sync(self: Task) -> None:
time_limit=60,
max_retries=3,
)
def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
def vespa_metadata_sync_task(
self: Task, document_id: str, tenant_id: str | None
) -> bool:
task_logger.info(f"document_id={document_id}")
try:
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name

View File

@ -4,6 +4,7 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
@ -17,7 +18,7 @@ from danswer.connectors.models import IndexAttemptMetadata
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
@ -46,6 +47,7 @@ def _get_connector_runner(
attempt: IndexAttempt,
start_time: datetime,
end_time: datetime,
tenant_id: str | None,
) -> ConnectorRunner:
"""
NOTE: `start_time` and `end_time` are only used for poll connectors
@ -87,8 +89,7 @@ def _get_connector_runner(
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
db_session: Session, index_attempt: IndexAttempt, tenant_id: str | None
) -> None:
"""
1. Get documents which are either new or updated from specified application
@ -129,6 +130,7 @@ def _run_indexing(
or (search_settings.status == IndexModelStatus.FUTURE)
),
db_session=db_session,
tenant_id=tenant_id,
)
db_cc_pair = index_attempt.connector_credential_pair
@ -185,6 +187,7 @@ def _run_indexing(
attempt=index_attempt,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id,
)
all_connector_doc_ids: set[str] = set()
@ -212,7 +215,9 @@ def _run_indexing(
db_session.refresh(index_attempt)
if index_attempt.status != IndexingStatus.IN_PROGRESS:
# Likely due to user manually disabling it or model swap
raise RuntimeError("Index Attempt was canceled")
raise RuntimeError(
f"Index Attempt was canceled, status is {index_attempt.status}"
)
batch_description = []
for doc in doc_batch:
@ -373,12 +378,21 @@ def _run_indexing(
)
def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt:
def _prepare_index_attempt(
db_session: Session, index_attempt_id: int, tenant_id: str | None
) -> IndexAttempt:
# make sure that the index attempt can't change in between checking the
# status and marking it as in_progress. This setting will be discarded
# after the next commit:
# https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#setting-isolation-for-individual-transactions
db_session.connection(execution_options={"isolation_level": "SERIALIZABLE"}) # type: ignore
if tenant_id is not None:
# Explicitly set the search path for the given tenant
db_session.execute(text(f'SET search_path TO "{tenant_id}"'))
# Verify the search path was set correctly
result = db_session.execute(text("SHOW search_path"))
current_search_path = result.scalar()
logger.info(f"Current search path set to: {current_search_path}")
attempt = get_index_attempt(
db_session=db_session,
@ -401,12 +415,11 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt:
def run_indexing_entrypoint(
index_attempt_id: int, connector_credential_pair_id: int, is_ee: bool = False
index_attempt_id: int,
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
) -> None:
"""Entrypoint for indexing run when using dask distributed.
Wraps the actual logic in a `try` block so that we can catch any exceptions
and mark the attempt as failed."""
try:
if is_ee:
global_version.set_ee()
@ -416,26 +429,29 @@ def run_indexing_entrypoint(
IndexAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
)
with Session(get_sqlalchemy_engine()) as db_session:
# make sure that it is valid to run this indexing attempt + mark it
# as in progress
attempt = _prepare_index_attempt(db_session, index_attempt_id)
with get_session_with_tenant(tenant_id) as db_session:
attempt = _prepare_index_attempt(db_session, index_attempt_id, tenant_id)
logger.info(
f"Indexing starting: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing starting for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
_run_indexing(db_session, attempt)
_run_indexing(db_session, attempt, tenant_id)
logger.info(
f"Indexing finished: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing finished for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
except Exception as e:
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
logger.exception(
f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}"
)

View File

@ -6,6 +6,8 @@ import dask
from dask.distributed import Client
from dask.distributed import Future
from distributed import LocalCluster
from sqlalchemy import text
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.orm import Session
from danswer.background.indexing.dask_utils import ResourceLogger
@ -15,14 +17,16 @@ from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
from danswer.configs.constants import TENANT_ID_PREFIX
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import SqlEngine
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
@ -153,13 +157,15 @@ def _mark_run_failed(
"""Main funcs"""
def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
def create_indexing_jobs(
existing_jobs: dict[int, Future | SimpleJob], tenant_id: str | None
) -> None:
"""Creates new indexing jobs for each connector / credential pair which is:
1. Enabled
2. `refresh_frequency` time has passed since the last indexing run for this pair
3. There is not already an ongoing indexing attempt for this pair
"""
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
ongoing: set[tuple[int | None, int]] = set()
for attempt_id in existing_jobs:
attempt = get_index_attempt(
@ -214,11 +220,12 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
def cleanup_indexing_jobs(
existing_jobs: dict[int, Future | SimpleJob],
tenant_id: str | None,
timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
) -> dict[int, Future | SimpleJob]:
existing_jobs_copy = existing_jobs.copy()
# clean up completed jobs
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
for attempt_id, job in existing_jobs.items():
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=attempt_id
@ -256,38 +263,41 @@ def cleanup_indexing_jobs(
)
# clean up in-progress jobs that were never completed
connectors = fetch_connectors(db_session)
for connector in connectors:
in_progress_indexing_attempts = get_inprogress_index_attempts(
connector.id, db_session
)
for index_attempt in in_progress_indexing_attempts:
if index_attempt.id in existing_jobs:
# If index attempt is canceled, stop the run
if index_attempt.status == IndexingStatus.FAILED:
existing_jobs[index_attempt.id].cancel()
# check to see if the job has been updated in last `timeout_hours` hours, if not
# assume it to frozen in some bad state and just mark it as failed. Note: this relies
# on the fact that the `time_updated` field is constantly updated every
# batch of documents indexed
current_db_time = get_db_current_time(db_session=db_session)
time_since_update = current_db_time - index_attempt.time_updated
if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
existing_jobs[index_attempt.id].cancel()
try:
connectors = fetch_connectors(db_session)
for connector in connectors:
in_progress_indexing_attempts = get_inprogress_index_attempts(
connector.id, db_session
)
for index_attempt in in_progress_indexing_attempts:
if index_attempt.id in existing_jobs:
# If index attempt is canceled, stop the run
if index_attempt.status == IndexingStatus.FAILED:
existing_jobs[index_attempt.id].cancel()
# check to see if the job has been updated in last `timeout_hours` hours, if not
# assume it to frozen in some bad state and just mark it as failed. Note: this relies
# on the fact that the `time_updated` field is constantly updated every
# batch of documents indexed
current_db_time = get_db_current_time(db_session=db_session)
time_since_update = current_db_time - index_attempt.time_updated
if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
existing_jobs[index_attempt.id].cancel()
_mark_run_failed(
db_session=db_session,
index_attempt=index_attempt,
failure_reason="Indexing run frozen - no updates in the last three hours. "
"The run will be re-attempted at next scheduled indexing time.",
)
else:
# If job isn't known, simply mark it as failed
_mark_run_failed(
db_session=db_session,
index_attempt=index_attempt,
failure_reason="Indexing run frozen - no updates in the last three hours. "
"The run will be re-attempted at next scheduled indexing time.",
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
)
else:
# If job isn't known, simply mark it as failed
_mark_run_failed(
db_session=db_session,
index_attempt=index_attempt,
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
)
except ProgrammingError:
logger.debug(f"No Connector Table exists for: {tenant_id}")
return existing_jobs_copy
@ -295,13 +305,15 @@ def kickoff_indexing_jobs(
existing_jobs: dict[int, Future | SimpleJob],
client: Client | SimpleJobClient,
secondary_client: Client | SimpleJobClient,
tenant_id: str | None,
) -> dict[int, Future | SimpleJob]:
existing_jobs_copy = existing_jobs.copy()
engine = get_sqlalchemy_engine()
current_session = get_session_with_tenant(tenant_id)
# Don't include jobs waiting in the Dask queue that just haven't started running
# Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
with Session(engine) as db_session:
with current_session as db_session:
# get_not_started_index_attempts orders its returned results from oldest to newest
# we must process attempts in a FIFO manner to prevent connector starvation
new_indexing_attempts = [
@ -332,7 +344,7 @@ def kickoff_indexing_jobs(
logger.warning(
f"Skipping index attempt as Connector has been deleted: {attempt}"
)
with Session(engine) as db_session:
with current_session as db_session:
mark_attempt_failed(
attempt, db_session, failure_reason="Connector is null"
)
@ -341,7 +353,7 @@ def kickoff_indexing_jobs(
logger.warning(
f"Skipping index attempt as Credential has been deleted: {attempt}"
)
with Session(engine) as db_session:
with current_session as db_session:
mark_attempt_failed(
attempt, db_session, failure_reason="Credential is null"
)
@ -352,6 +364,7 @@ def kickoff_indexing_jobs(
run = client.submit(
run_indexing_entrypoint,
attempt.id,
tenant_id,
attempt.connector_credential_pair_id,
global_version.is_ee_version(),
pure=False,
@ -363,6 +376,7 @@ def kickoff_indexing_jobs(
run = secondary_client.submit(
run_indexing_entrypoint,
attempt.id,
tenant_id,
attempt.connector_credential_pair_id,
global_version.is_ee_version(),
pure=False,
@ -398,42 +412,40 @@ def kickoff_indexing_jobs(
return existing_jobs_copy
def get_all_tenant_ids() -> list[str] | list[None]:
if not MULTI_TENANT:
return [None]
with get_session_with_tenant(tenant_id="public") as session:
result = session.execute(
text(
"""
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
)
)
tenant_ids = [row[0] for row in result]
valid_tenants = [
tenant
for tenant in tenant_ids
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
]
return valid_tenants
def update_loop(
delay: int = 10,
num_workers: int = NUM_INDEXING_WORKERS,
num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
) -> None:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if search_settings.provider_type is None:
logger.notice("Running a first inference to warm up embedding model")
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(
embedding_model=embedding_model,
)
logger.notice("First inference complete.")
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient
if DASK_JOB_CLIENT_ENABLED:
cluster_primary = LocalCluster(
n_workers=num_workers,
threads_per_worker=1,
# there are warning about high memory usage + "Event loop unresponsive"
# which are not relevant to us since our workers are expected to use a
# lot of memory + involve CPU intensive tasks that will not relinquish
# the event loop
silence_logs=logging.ERROR,
)
cluster_secondary = LocalCluster(
@ -449,7 +461,7 @@ def update_loop(
client_primary = SimpleJobClient(n_workers=num_workers)
client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
existing_jobs: dict[int, Future | SimpleJob] = {}
existing_jobs: dict[str | None, dict[int, Future | SimpleJob]] = {}
logger.notice("Startup complete. Waiting for indexing jobs...")
while True:
@ -458,24 +470,58 @@ def update_loop(
logger.debug(f"Running update, current UTC time: {start_time_utc}")
if existing_jobs:
# TODO: make this debug level once the "no jobs are being scheduled" issue is resolved
logger.debug(
"Found existing indexing jobs: "
f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
f"{[(tenant_id, list(jobs.keys())) for tenant_id, jobs in existing_jobs.items()]}"
)
try:
with Session(get_sqlalchemy_engine()) as db_session:
check_index_swap(db_session)
existing_jobs = cleanup_indexing_jobs(existing_jobs=existing_jobs)
create_indexing_jobs(existing_jobs=existing_jobs)
existing_jobs = kickoff_indexing_jobs(
existing_jobs=existing_jobs,
client=client_primary,
secondary_client=client_secondary,
)
tenants = get_all_tenant_ids()
for tenant_id in tenants:
try:
logger.debug(
f"Processing {'index attempts' if tenant_id is None else f'tenant {tenant_id}'}"
)
with get_session_with_tenant(tenant_id) as db_session:
check_index_swap(db_session=db_session)
if not MULTI_TENANT:
search_settings = get_current_search_settings(db_session)
if search_settings.provider_type is None:
logger.notice(
"Running a first inference to warm up embedding model"
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
logger.notice("First inference complete.")
tenant_jobs = existing_jobs.get(tenant_id, {})
tenant_jobs = cleanup_indexing_jobs(
existing_jobs=tenant_jobs, tenant_id=tenant_id
)
create_indexing_jobs(existing_jobs=tenant_jobs, tenant_id=tenant_id)
tenant_jobs = kickoff_indexing_jobs(
existing_jobs=tenant_jobs,
client=client_primary,
secondary_client=client_secondary,
tenant_id=tenant_id,
)
existing_jobs[tenant_id] = tenant_jobs
except Exception as e:
logger.exception(
f"Failed to process tenant {tenant_id or 'default'}: {e}"
)
except Exception as e:
logger.exception(f"Failed to run update due to {e}")
sleep_time = delay - (time.time() - start)
if sleep_time > 0:
time.sleep(sleep_time)
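For readers following the tenant loop above, a minimal sketch (not part of the commit) of how the two new helpers compose for any per-tenant maintenance job; do_per_tenant_work is a hypothetical stand-in for real work.

from danswer.background.update import get_all_tenant_ids
from danswer.db.engine import get_session_with_tenant
from sqlalchemy.orm import Session


def do_per_tenant_work(db_session: Session) -> None:
    ...  # placeholder: any query issued here runs against the tenant's schema


def run_for_all_tenants() -> None:
    # [None] in single-tenant mode, otherwise every "tenant_"-prefixed schema
    for tenant_id in get_all_tenant_ids():
        with get_session_with_tenant(tenant_id) as db_session:
            do_per_tenant_work(db_session)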

View File

@ -429,3 +429,5 @@ SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "")
DATA_PLANE_SECRET = os.environ.get("DATA_PLANE_SECRET", "")
EXPECTED_API_KEY = os.environ.get("EXPECTED_API_KEY", "")
ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"

View File

@ -31,6 +31,9 @@ DISABLED_GEN_AI_MSG = (
"You can still use Danswer as a search engine."
)
# Prefix used for all tenant ids
TENANT_ID_PREFIX = "tenant_"
# Postgres connection constants for application_name
POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"

View File

@ -13,7 +13,7 @@ from sqlalchemy.future import select
from danswer.auth.schemas import UserRole
from danswer.db.engine import get_async_session
from danswer.db.engine import get_sqlalchemy_async_engine
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
@ -34,7 +34,7 @@ def get_default_admin_user_emails() -> list[str]:
async def get_user_count() -> int:
async with AsyncSession(get_sqlalchemy_async_engine()) as asession:
async with get_async_session_with_tenant() as asession:
stmt = select(func.count(User.id))
result = await asession.execute(stmt)
user_count = result.scalar()

View File

@ -390,6 +390,7 @@ def add_credential_to_connector(
)
db_session.add(association)
db_session.flush() # make sure the association has an id
db_session.refresh(association)
if groups and access_type != AccessType.SYNC:
_relate_groups_to_cc_pair__no_commit(

View File

@ -1,16 +1,16 @@
import contextlib
import contextvars
import re
import threading
import time
from collections.abc import AsyncGenerator
from collections.abc import Generator
from contextlib import asynccontextmanager
from contextlib import contextmanager
from datetime import datetime
from typing import Any
from typing import ContextManager
import jwt
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import event
@ -39,7 +39,7 @@ from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.utils.logger import setup_logger
from shared_configs.configs import current_tenant_id
logger = setup_logger()
@ -230,18 +230,8 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
return _ASYNC_ENGINE
# Context variable to store the current tenant ID
# This allows us to maintain tenant-specific context throughout the request lifecycle
# The default value is set to POSTGRES_DEFAULT_SCHEMA for non-multi-tenant setups
# This context variable works in both synchronous and asynchronous contexts
# In async code, it's automatically carried across coroutines
# In sync code, it's managed per thread
current_tenant_id = contextvars.ContextVar(
"current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
)
# Dependency to get the current tenant ID and set the context variable
# Dependency to resolve the current tenant ID
# Falls back to the default schema when no token is present
def get_current_tenant_id(request: Request) -> str:
"""Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
if not MULTI_TENANT:
@ -251,32 +241,31 @@ def get_current_tenant_id(request: Request) -> str:
token = request.cookies.get("tenant_details")
if not token:
current_value = current_tenant_id.get()
# If no token is present, use the default schema or handle accordingly
tenant_id = POSTGRES_DEFAULT_SCHEMA
current_tenant_id.set(tenant_id)
return tenant_id
return current_value
try:
payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
tenant_id = payload.get("tenant_id")
if not tenant_id:
raise HTTPException(
status_code=400, detail="Invalid token: tenant_id missing"
)
return current_tenant_id.get()
if not is_valid_schema_name(tenant_id):
raise ValueError("Invalid tenant ID format")
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
current_tenant_id.set(tenant_id)
return tenant_id
except jwt.InvalidTokenError:
raise HTTPException(status_code=401, detail="Invalid token format")
except ValueError as e:
# Let the 400 error bubble up
raise HTTPException(status_code=400, detail=str(e))
except Exception:
return current_tenant_id.get()
except Exception as e:
logger.error(f"Unexpected error in get_current_tenant_id: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
def get_session_with_tenant(tenant_id: str | None = None) -> Session:
@asynccontextmanager
async def get_async_session_with_tenant(
tenant_id: str | None = None,
) -> AsyncGenerator[AsyncSession, None]:
if tenant_id is None:
tenant_id = current_tenant_id.get()
@ -284,20 +273,78 @@ def get_session_with_tenant(tenant_id: str | None = None) -> Session:
logger.error(f"Invalid tenant ID: {tenant_id}")
raise Exception("Invalid tenant ID")
engine = SqlEngine.get_engine()
session = Session(engine, expire_on_commit=False)
engine = get_sqlalchemy_async_engine()
async_session_factory = sessionmaker(
bind=engine, expire_on_commit=False, class_=AsyncSession
) # type: ignore
@event.listens_for(session, "after_begin")
def set_search_path(session: Session, transaction: Any, connection: Any) -> None:
connection.execute(text("SET search_path TO :schema"), {"schema": tenant_id})
return session
async with async_session_factory() as session:
try:
# Set the search_path to the tenant's schema
await session.execute(text(f'SET search_path = "{tenant_id}"'))
except Exception as e:
logger.error(f"Error setting search_path: {str(e)}")
            # Re-raise so we never yield a session bound to the wrong schema
raise
else:
yield session
def get_session(
tenant_id: str = Depends(get_current_tenant_id),
@contextmanager
def get_session_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
"""Generate a database session with the appropriate tenant schema set."""
engine = get_sqlalchemy_engine()
if tenant_id is None:
tenant_id = current_tenant_id.get()
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Establish a raw connection without starting a transaction
with engine.connect() as connection:
# Access the raw DBAPI connection
dbapi_connection = connection.connection
# Execute SET search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path TO "{tenant_id}"')
# Optionally verify the search_path was set correctly
cursor.execute("SHOW search_path")
cursor.fetchone()
finally:
cursor.close()
# Proceed to create a session using the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
finally:
# Reset search_path to default after the session is used
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
def get_session_generator_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
with get_session_with_tenant(tenant_id) as session:
yield session
def get_session() -> Generator[Session, None, None]:
"""Generate a database session with the appropriate tenant schema set."""
tenant_id = current_tenant_id.get()
if tenant_id == "public" and MULTI_TENANT:
raise HTTPException(status_code=401, detail="User must authenticate")
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
@ -308,10 +355,9 @@ def get_session(
yield session
async def get_async_session(
tenant_id: str = Depends(get_current_tenant_id),
) -> AsyncGenerator[AsyncSession, None]:
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
"""Generate an async database session with the appropriate tenant schema set."""
tenant_id = current_tenant_id.get()
engine = get_sqlalchemy_async_engine()
async with AsyncSession(engine, expire_on_commit=False) as async_session:
if MULTI_TENANT:
@ -324,7 +370,7 @@ async def get_async_session(
def get_session_context_manager() -> ContextManager[Session]:
"""Context manager for database sessions."""
return contextlib.contextmanager(get_session)()
return contextlib.contextmanager(get_session_generator_with_tenant)()
def get_session_factory() -> sessionmaker[Session]:
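A hedged sketch (illustrative, not from the diff) of how an API route consumes the reworked dependencies: get_session now yields a session whose search_path already matches the tenant resolved from the contextvar, while get_current_tenant_id exposes that tenant id directly. The route path is hypothetical.

from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session

from danswer.db.engine import get_current_tenant_id, get_session

router = APIRouter()


@router.get("/example/tenant-info")  # hypothetical route, for illustration only
def tenant_info(
    db_session: Session = Depends(get_session),
    tenant_id: str = Depends(get_current_tenant_id),
) -> dict[str, str]:
    # Any query through db_session runs with search_path set to this tenant's schema.
    return {"tenant_id": tenant_id}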

View File

@ -1763,3 +1763,23 @@ class UsageReport(Base):
requestor = relationship("User")
file = relationship("PGFileStore")
"""
Multi-tenancy related tables
"""
class PublicBase(DeclarativeBase):
__abstract__ = True
class UserTenantMapping(Base):
__tablename__ = "user_tenant_mapping"
__table_args__ = (
UniqueConstraint("email", "tenant_id", name="uq_user_tenant"),
{"schema": "public"},
)
email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
tenant_id: Mapped[str] = mapped_column(String, nullable=False)

View File

@ -137,6 +137,7 @@ def index_doc_batch_with_handler(
attempt_id: int | None,
db_session: Session,
ignore_time_skip: bool = False,
tenant_id: str | None = None,
) -> tuple[int, int]:
r = (0, 0)
try:
@ -148,6 +149,7 @@ def index_doc_batch_with_handler(
index_attempt_metadata=index_attempt_metadata,
db_session=db_session,
ignore_time_skip=ignore_time_skip,
tenant_id=tenant_id,
)
except Exception as e:
if INDEXING_EXCEPTION_LIMIT == 0:
@ -261,6 +263,7 @@ def index_doc_batch(
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
ignore_time_skip: bool = False,
tenant_id: str | None = None,
) -> tuple[int, int]:
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
@ -324,6 +327,7 @@ def index_doc_batch(
if chunk.source_document.id in ctx.id_to_db_doc_map
else DEFAULT_BOOST
),
tenant_id=tenant_id,
)
for chunk in chunks_with_embeddings
]
@ -373,6 +377,7 @@ def build_indexing_pipeline(
chunker: Chunker | None = None,
ignore_time_skip: bool = False,
attempt_id: int | None = None,
tenant_id: str | None = None,
) -> IndexingPipelineProtocol:
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
search_settings = get_current_search_settings(db_session)
@ -416,4 +421,5 @@ def build_indexing_pipeline(
ignore_time_skip=ignore_time_skip,
attempt_id=attempt_id,
db_session=db_session,
tenant_id=tenant_id,
)

View File

@ -75,6 +75,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
negative -> ranked lower.
"""
tenant_id: str | None = None
access: "DocumentAccess"
document_sets: set[str]
boost: int
@ -86,6 +87,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access: "DocumentAccess",
document_sets: set[str],
boost: int,
tenant_id: str | None,
) -> "DocMetadataAwareIndexChunk":
index_chunk_data = index_chunk.model_dump()
return cls(
@ -93,6 +95,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access=access,
document_sets=document_sets,
boost=boost,
tenant_id=tenant_id,
)
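A sketch (not in the commit) of the per-chunk call shape that index_doc_batch now produces; chunk and access stand in for an existing IndexChunk and DocumentAccess, and the import path is an assumption about where this model lives.

from danswer.indexing.models import DocMetadataAwareIndexChunk  # assumed module path


def attach_tenant_metadata(chunk, access, tenant_id: str | None) -> DocMetadataAwareIndexChunk:
    # Wrap an embedded chunk with access info, document sets, boost and the tenant id,
    # mirroring what index_doc_batch builds before writing to the document index.
    return DocMetadataAwareIndexChunk.from_index_chunk(
        index_chunk=chunk,    # an IndexChunk with embeddings (assumed)
        access=access,        # the DocumentAccess computed for the source document (assumed)
        document_sets=set(),  # no document sets in this minimal sketch
        boost=0,
        tenant_id=tenant_id,  # None in single-tenant deployments, schema name otherwise
    )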

View File

@ -3,15 +3,21 @@ from collections.abc import Iterator
from contextlib import contextmanager
from typing import cast
from fastapi import HTTPException
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import is_valid_schema_name
from danswer.db.models import KVStore
from danswer.key_value_store.interface import JSON_ro
from danswer.key_value_store.interface import KeyValueStore
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from shared_configs.configs import current_tenant_id
logger = setup_logger()
@ -28,6 +34,16 @@ class PgRedisKVStore(KeyValueStore):
def get_session(self) -> Iterator[Session]:
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
tenant_id = current_tenant_id.get()
if tenant_id == "public":
raise HTTPException(
status_code=401, detail="User must authenticate"
)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
session.execute(text(f'SET search_path = "{tenant_id}"'))
yield session
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:

View File

@ -29,6 +29,7 @@ from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import LOG_ENDPOINT_LATENCY
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import OAUTH_CLIENT_ID
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
@ -157,6 +158,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
verify_auth = fetch_versioned_implementation(
"danswer.auth.users", "verify_auth_setting"
)
# Will throw exception if an issue is found
verify_auth()
@ -169,11 +171,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
# fill up Postgres connection pools
await warm_up_connections()
# We cache this at the beginning so there is no delay in the first telemetry
get_or_generate_uuid()
if not MULTI_TENANT:
# We cache this at the beginning so there is no delay in the first telemetry
get_or_generate_uuid()
with Session(engine) as db_session:
setup_danswer(db_session)
# If we are multi-tenant, we need to only set up initial public tables
with Session(engine) as db_session:
setup_danswer(db_session)
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
yield

View File

@ -22,6 +22,7 @@ from danswer.db.connector_credential_pair import (
update_connector_credential_pair_from_id,
)
from danswer.db.document import get_document_counts_for_cc_pairs
from danswer.db.engine import current_tenant_id
from danswer.db.engine import get_session
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
@ -257,7 +258,9 @@ def prune_cc_pair(
f"credential_id={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
tasks_created = try_creating_prune_generator_task(cc_pair, db_session, r)
tasks_created = try_creating_prune_generator_task(
cc_pair, db_session, r, current_tenant_id.get()
)
if not tasks_created:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
@ -342,7 +345,7 @@ def sync_cc_pair(
logger.info(f"Syncing the {cc_pair.connector.name} connector.")
sync_external_doc_permissions_task.apply_async(
kwargs=dict(cc_pair_id=cc_pair_id),
kwargs=dict(cc_pair_id=cc_pair_id, tenant_id=current_tenant_id.get()),
)
return StatusResponse(

View File

@ -20,6 +20,7 @@ from danswer.db.connector_credential_pair import (
update_connector_credential_pair_from_id,
)
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.feedback import fetch_docs_ranked_by_boost
@ -146,6 +147,7 @@ def create_deletion_attempt_for_connector_id(
connector_credential_pair_identifier: ConnectorCredentialPairIdentifier,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> None:
connector_id = connector_credential_pair_identifier.connector_id
credential_id = connector_credential_pair_identifier.credential_id
@ -196,6 +198,7 @@ def create_deletion_attempt_for_connector_id(
celery_app.send_task(
"check_for_connector_deletion_task",
priority=DanswerCeleryPriority.HIGH,
kwargs={"tenant_id": tenant_id},
)
if cc_pair.connector.source == DocumentSource.FILE:

View File

@ -2,17 +2,21 @@ import re
from datetime import datetime
from datetime import timezone
import jwt
from email_validator import validate_email
from fastapi import APIRouter
from fastapi import Body
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import status
from psycopg2.errors import UniqueViolation
from pydantic import BaseModel
from sqlalchemy import Column
from sqlalchemy import desc
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from danswer.auth.invited_users import get_invited_users
@ -26,9 +30,12 @@ from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import optional_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.constants import AuthType
from danswer.db.engine import current_tenant_id
from danswer.db.engine import get_session
from danswer.db.models import AccessToken
from danswer.db.models import DocumentSet__User
@ -48,10 +55,13 @@ from danswer.server.manage.models import UserRoleUpdateRequest
from danswer.server.models import FullUserSnapshot
from danswer.server.models import InvitedUserSnapshot
from danswer.server.models import MinimalUserSnapshot
from danswer.server.utils import send_user_email_invite
from danswer.utils.logger import setup_logger
from ee.danswer.db.api_key import is_api_key_email_address
from ee.danswer.db.external_perm import delete_user__ext_group_for_user__no_commit
from ee.danswer.db.user_group import remove_curator_status__no_commit
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import remove_users_from_tenant
logger = setup_logger()
@ -171,12 +181,33 @@ def bulk_invite_users(
raise HTTPException(
status_code=400, detail="Auth is disabled, cannot invite users"
)
tenant_id = current_tenant_id.get()
normalized_emails = []
for email in emails:
email_info = validate_email(email) # can raise EmailNotValidError
normalized_emails.append(email_info.normalized) # type: ignore
if MULTI_TENANT:
try:
add_users_to_tenant(normalized_emails, tenant_id)
except IntegrityError as e:
if isinstance(e.orig, UniqueViolation):
raise HTTPException(
status_code=400,
detail="User has already been invited to a Danswer organization",
)
raise
all_emails = list(set(normalized_emails) | set(get_invited_users()))
if MULTI_TENANT and ENABLE_EMAIL_INVITES:
try:
for email in all_emails:
send_user_email_invite(email, current_user)
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")
return write_invited_users(all_emails)
@ -187,6 +218,10 @@ def remove_invited_user(
) -> int:
user_emails = get_invited_users()
remaining_users = [user for user in user_emails if user != user_email.user_email]
tenant_id = current_tenant_id.get()
remove_users_from_tenant([user_email.user_email], tenant_id)
return write_invited_users(remaining_users)
@ -330,6 +365,35 @@ async def get_user_role(user: User = Depends(current_user)) -> UserRoleResponse:
return UserRoleResponse(role=user.role)
def get_current_token_expiration_jwt(
user: User | None, request: Request
) -> datetime | None:
if user is None:
return None
try:
# Get the JWT from the cookie
jwt_token = request.cookies.get("fastapiusersauth")
if not jwt_token:
logger.error("No JWT token found in cookies")
return None
# Decode the JWT
decoded_token = jwt.decode(jwt_token, options={"verify_signature": False})
# Get the 'exp' (expiration) claim from the token
exp = decoded_token.get("exp")
if exp:
return datetime.fromtimestamp(exp)
else:
logger.error("No 'exp' claim found in JWT")
return None
except Exception as e:
logger.error(f"Error decoding JWT: {e}")
return None
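For reference, a small hedged helper (not in the diff) using the same unverified decode as above to compute how long the session cookie has left; it mirrors getSecondsUntilExpiration on the web side.

from datetime import datetime, timezone

import jwt


def seconds_until_expiration(jwt_token: str) -> int | None:
    # Signature verification is intentionally skipped: we only need the exp claim.
    decoded = jwt.decode(jwt_token, options={"verify_signature": False})
    exp = decoded.get("exp")
    if exp is None:
        return None
    remaining = datetime.fromtimestamp(exp, tz=timezone.utc) - datetime.now(timezone.utc)
    return max(0, int(remaining.total_seconds()))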
def get_current_token_creation(
user: User | None, db_session: Session
) -> datetime | None:
@ -357,6 +421,7 @@ def get_current_token_creation(
@router.get("/me")
def verify_user_logged_in(
request: Request,
user: User | None = Depends(optional_user),
db_session: Session = Depends(get_session),
) -> UserInfo:
@ -380,7 +445,9 @@ def verify_user_logged_in(
detail="Access denied. User's OIDC token has expired.",
)
token_created_at = get_current_token_creation(user, db_session)
token_created_at = (
None if MULTI_TENANT else get_current_token_creation(user, db_session)
)
user_info = UserInfo.from_model(
user,
current_token_created_at=token_created_at,

View File

@ -73,6 +73,7 @@ from danswer.server.query_and_chat.models import UpdateChatSessionThreadRequest
from danswer.server.query_and_chat.token_limit import check_token_rate_limits
from danswer.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/chat")

View File

@ -1,7 +1,17 @@
import json
import smtplib
from datetime import datetime
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Any
from danswer.configs.app_configs import SMTP_PASS
from danswer.configs.app_configs import SMTP_PORT
from danswer.configs.app_configs import SMTP_SERVER
from danswer.configs.app_configs import SMTP_USER
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.db.models import User
class DateTimeEncoder(json.JSONEncoder):
"""Custom JSON encoder that converts datetime objects to ISO format strings."""
@ -43,3 +53,28 @@ def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
masked_creds[key] = mask_string(val)
return masked_creds
def send_user_email_invite(user_email: str, current_user: User) -> None:
msg = MIMEMultipart()
msg["Subject"] = "Invitation to Join Danswer Workspace"
msg["To"] = user_email
msg["From"] = current_user.email
email_body = f"""
Hello,
You have been invited to join a workspace on Danswer.
To join the workspace, please sign in at the following link:
{WEB_DOMAIN}/auth/login
Best regards,
The Danswer Team"""
msg.attach(MIMEText(email_body, "plain"))
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as smtp_server:
smtp_server.starttls()
smtp_server.login(SMTP_USER, SMTP_PASS)
smtp_server.send_message(msg)

View File

@ -4,6 +4,7 @@ from sqlalchemy.orm import Session
from danswer.chat.load_yamls import load_chat_yamls
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.configs.constants import KV_SEARCH_SETTINGS
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
@ -98,7 +99,8 @@ def setup_danswer(db_session: Session) -> None:
# Does the user need to trigger a reindexing to bring the document index
# into a good state, marked in the kv store
mark_reindex_flag(db_session)
if not MULTI_TENANT:
mark_reindex_flag(db_session)
# ensure Vespa is setup correctly
logger.notice("Verifying Document Index(s) is/are available.")

View File

@ -1,12 +1,12 @@
from datetime import timedelta
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import celery_app
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.update import get_all_tenant_ids
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.chat import delete_chat_sessions_older_than
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.server.settings.store import load_settings
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
@ -32,6 +32,7 @@ from ee.danswer.external_permissions.permission_sync import (
run_external_group_permission_sync,
)
from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
from shared_configs.configs import current_tenant_id
logger = setup_logger()
@ -41,22 +42,26 @@ global_version.set_ee()
@build_celery_task_wrapper(name_sync_external_doc_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_doc_permissions_task(cc_pair_id: int) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
def sync_external_doc_permissions_task(cc_pair_id: int, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
@build_celery_task_wrapper(name_sync_external_group_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_group_permissions_task(cc_pair_id: int) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
def sync_external_group_permissions_task(
cc_pair_id: int, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id) as db_session:
run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(retention_limit_days: int) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
def perform_ttl_management_task(
retention_limit_days: int, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id) as db_session:
delete_chat_sessions_older_than(retention_limit_days, db_session)
@ -67,16 +72,16 @@ def perform_ttl_management_task(retention_limit_days: int) -> None:
name="check_sync_external_doc_permissions_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_doc_permissions_task() -> None:
def check_sync_external_doc_permissions_task(tenant_id: str | None) -> None:
"""Runs periodically to sync external permissions"""
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
if should_perform_external_doc_permissions_check(
cc_pair=cc_pair, db_session=db_session
):
sync_external_doc_permissions_task.apply_async(
kwargs=dict(cc_pair_id=cc_pair.id),
kwargs=dict(cc_pair_id=cc_pair.id, tenant_id=tenant_id),
)
@ -84,16 +89,16 @@ def check_sync_external_doc_permissions_task() -> None:
name="check_sync_external_group_permissions_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_group_permissions_task() -> None:
def check_sync_external_group_permissions_task(tenant_id: str | None) -> None:
"""Runs periodically to sync external group permissions"""
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
if should_perform_external_group_permissions_check(
cc_pair=cc_pair, db_session=db_session
):
sync_external_group_permissions_task.apply_async(
kwargs=dict(cc_pair_id=cc_pair.id),
kwargs=dict(cc_pair_id=cc_pair.id, tenant_id=tenant_id),
)
@ -101,25 +106,33 @@ def check_sync_external_group_permissions_task() -> None:
name="check_ttl_management_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task() -> None:
def check_ttl_management_task(tenant_id: str | None) -> None:
"""Runs periodically to check if any ttl tasks should be run and adds them
to the queue"""
token = None
if MULTI_TENANT and tenant_id is not None:
token = current_tenant_id.set(tenant_id)
settings = load_settings()
retention_limit_days = settings.maximum_chat_retention_days
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
if should_perform_chat_ttl_check(retention_limit_days, db_session):
perform_ttl_management_task.apply_async(
kwargs=dict(retention_limit_days=retention_limit_days),
kwargs=dict(
retention_limit_days=retention_limit_days, tenant_id=tenant_id
),
)
if token is not None:
current_tenant_id.reset(token)
@celery_app.task(
name="autogenerate_usage_report_task",
soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task() -> None:
def autogenerate_usage_report_task(tenant_id: str | None) -> None:
"""This generates usage report under the /admin/generate-usage/report endpoint"""
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
create_new_usage_report(
db_session=db_session,
user_id=None,
@ -130,22 +143,48 @@ def autogenerate_usage_report_task() -> None:
#####
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
"sync-external-doc-permissions": {
tenant_ids = get_all_tenant_ids()
tasks_to_schedule = [
{
"name": "sync-external-doc-permissions",
"task": "check_sync_external_doc_permissions_task",
"schedule": timedelta(seconds=5), # TODO: optimize this
},
"sync-external-group-permissions": {
{
"name": "sync-external-group-permissions",
"task": "check_sync_external_group_permissions_task",
"schedule": timedelta(seconds=5), # TODO: optimize this
},
"autogenerate_usage_report": {
{
"name": "autogenerate_usage_report",
"task": "autogenerate_usage_report_task",
"schedule": timedelta(days=30), # TODO: change this to config flag
},
"check-ttl-management": {
{
"name": "check-ttl-management",
"task": "check_ttl_management_task",
"schedule": timedelta(hours=1),
},
**(celery_app.conf.beat_schedule or {}),
}
]
# Build the celery beat schedule dynamically
beat_schedule = {}
for tenant_id in tenant_ids:
for task in tasks_to_schedule:
task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task
beat_schedule[task_name] = {
"task": task["task"],
"schedule": task["schedule"],
"args": (tenant_id,), # Must pass tenant_id as an argument
}
# Include any existing beat schedules
existing_beat_schedule = celery_app.conf.beat_schedule or {}
beat_schedule.update(existing_beat_schedule)
# Update the Celery app configuration
celery_app.conf.beat_schedule = beat_schedule
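For clarity, roughly what the generated schedule contains for a single hypothetical tenant schema tenant_abc (names illustrative only):

from datetime import timedelta

beat_schedule_example = {
    "check-ttl-management-tenant_abc": {
        "task": "check_ttl_management_task",
        "schedule": timedelta(hours=1),
        "args": ("tenant_abc",),  # each task opens its sessions against this schema
    },
    # ...one entry per (task, tenant) pair, plus any pre-existing schedule entries
}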

View File

@ -2,9 +2,13 @@ def name_chat_ttl_task(retention_limit_days: int) -> str:
return f"chat_ttl_{retention_limit_days}_days"
def name_sync_external_doc_permissions_task(cc_pair_id: int) -> str:
def name_sync_external_doc_permissions_task(
cc_pair_id: int, tenant_id: str | None = None
) -> str:
return f"sync_external_doc_permissions_task__{cc_pair_id}"
def name_sync_external_group_permissions_task(cc_pair_id: int) -> str:
def name_sync_external_group_permissions_task(
cc_pair_id: int, tenant_id: str | None = None
) -> str:
return f"sync_external_group_permissions_task__{cc_pair_id}"

View File

@ -4,6 +4,7 @@ from httpx_oauth.clients.openid import OpenID
from danswer.auth.users import auth_backend
from danswer.auth.users import fastapi_users
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import OAUTH_CLIENT_ID
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import USER_AUTH_SECRET
@ -24,6 +25,7 @@ from ee.danswer.server.enterprise_settings.api import (
basic_router as enterprise_settings_router,
)
from ee.danswer.server.manage.standard_answer import router as standard_answer_router
from ee.danswer.server.middleware.tenant_tracking import add_tenant_id_middleware
from ee.danswer.server.query_and_chat.chat_backend import (
router as chat_router,
)
@ -53,6 +55,9 @@ def get_application() -> FastAPI:
application = get_application_base()
if MULTI_TENANT:
add_tenant_id_middleware(application, logger)
if AUTH_TYPE == AuthType.OIDC:
include_router_with_global_prefix_prepended(
application,

View File

@ -0,0 +1,60 @@
import logging
from collections.abc import Awaitable
from collections.abc import Callable
import jwt
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
from danswer.db.engine import is_valid_schema_name
from shared_configs.configs import current_tenant_id
def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None:
@app.middleware("http")
async def set_tenant_id(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
try:
logger.info(f"Request route: {request.url.path}")
if not MULTI_TENANT:
tenant_id = POSTGRES_DEFAULT_SCHEMA
else:
token = request.cookies.get("tenant_details")
if token:
try:
payload = jwt.decode(
token, SECRET_JWT_KEY, algorithms=["HS256"]
)
tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
if not is_valid_schema_name(tenant_id):
raise HTTPException(
status_code=400, detail="Invalid tenant ID format"
)
except jwt.InvalidTokenError:
tenant_id = POSTGRES_DEFAULT_SCHEMA
except Exception as e:
logger.error(
f"Unexpected error in set_tenant_id_middleware: {str(e)}"
)
raise HTTPException(
status_code=500, detail="Internal server error"
)
else:
tenant_id = POSTGRES_DEFAULT_SCHEMA
current_tenant_id.set(tenant_id)
logger.info(f"Middleware set current_tenant_id to: {tenant_id}")
response = await call_next(request)
return response
except Exception as e:
logger.error(f"Error in tenant ID middleware: {str(e)}")
raise
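A sketch (not part of the commit) of how a control plane could mint the tenant_details cookie this middleware decodes, assuming the same SECRET_JWT_KEY and HS256 algorithm:

import jwt

from danswer.configs.app_configs import SECRET_JWT_KEY


def mint_tenant_cookie(tenant_id: str) -> str:
    # The middleware only reads the tenant_id claim; any extra claims are ignored.
    return jwt.encode({"tenant_id": tenant_id}, SECRET_JWT_KEY, algorithm="HS256")

# e.g. response.set_cookie("tenant_details", mint_tenant_cookie("tenant_abc"))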

View File

@ -8,8 +8,11 @@ from danswer.db.engine import get_session_with_tenant
from danswer.setup import setup_danswer
from danswer.utils.logger import setup_logger
from ee.danswer.server.tenants.models import CreateTenantRequest
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
from ee.danswer.server.tenants.provisioning import run_alembic_migrations
from ee.danswer.server.tenants.provisioning import user_owns_a_tenant
from shared_configs.configs import current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@ -19,9 +22,15 @@ router = APIRouter(prefix="/tenants")
def create_tenant(
create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep)
) -> dict[str, str]:
try:
tenant_id = create_tenant_request.tenant_id
tenant_id = create_tenant_request.tenant_id
email = create_tenant_request.initial_admin_email
token = None
if user_owns_a_tenant(email):
raise HTTPException(
status_code=409, detail="User already belongs to an organization"
)
try:
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
@ -31,10 +40,14 @@ def create_tenant(
logger.info(f"Schema already exists for tenant {tenant_id}")
run_alembic_migrations(tenant_id)
token = current_tenant_id.set(tenant_id)
print("getting session", tenant_id)
with get_session_with_tenant(tenant_id) as db_session:
setup_danswer(db_session)
logger.info(f"Tenant {tenant_id} created successfully")
add_users_to_tenant([email], tenant_id)
return {
"status": "success",
"message": f"Tenant {tenant_id} created successfully",
@ -44,3 +57,6 @@ def create_tenant(
raise HTTPException(
status_code=500, detail=f"Failed to create tenant: {str(e)}"
)
finally:
if token is not None:
current_tenant_id.reset(token)

View File

@ -8,7 +8,9 @@ from sqlalchemy.schema import CreateSchema
from alembic import command
from alembic.config import Config
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import UserTenantMapping
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -61,3 +63,48 @@ def ensure_schema_exists(tenant_id: str) -> bool:
db_session.execute(stmt)
return True
return False
# For now, we're implementing a primitive mapping between users and tenants.
# This function is only used to determine a user's relationship to a tenant upon creation (implying ownership).
def user_owns_a_tenant(email: str) -> bool:
with get_session_with_tenant("public") as db_session:
result = (
db_session.query(UserTenantMapping)
.filter(UserTenantMapping.email == email)
.first()
)
return result is not None
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant("public") as db_session:
        try:
            for email in emails:
                db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
            db_session.commit()
        except Exception as e:
            logger.exception(f"Failed to add users to tenant {tenant_id}: {str(e)}")
            db_session.rollback()
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant("public") as db_session:
try:
mappings_to_delete = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email.in_(emails),
UserTenantMapping.tenant_id == tenant_id,
)
.all()
)
for mapping in mappings_to_delete:
db_session.delete(mapping)
db_session.commit()
except Exception as e:
logger.exception(
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
)
db_session.rollback()
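A hypothetical companion helper (not in the diff) that resolves which tenant an email maps to, following the same public-schema session pattern as the functions above:

from danswer.db.engine import get_session_with_tenant
from danswer.db.models import UserTenantMapping


def get_tenant_for_user(email: str) -> str | None:
    with get_session_with_tenant("public") as db_session:
        mapping = (
            db_session.query(UserTenantMapping)
            .filter(UserTenantMapping.email == email)
            .first()
        )
        return mapping.tenant_id if mapping else None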

View File

@ -94,6 +94,7 @@ def generate_dummy_chunk(
),
document_sets={document_set for document_set in document_set_names},
boost=random.randint(-1, 1),
tenant_id="public",
)

View File

@ -1,3 +1,4 @@
import contextvars
import os
from typing import List
from urllib.parse import urlparse
@ -109,3 +110,5 @@ if CORS_ALLOWED_ORIGIN_ENV:
else:
# If the environment variable is empty, allow all origins
CORS_ALLOWED_ORIGIN = ["*"]
current_tenant_id = contextvars.ContextVar("current_tenant_id", default="public")
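A minimal sketch of the set/reset discipline the background tasks above follow whenever they rely on this contextvar rather than an explicit tenant_id argument:

from shared_configs.configs import current_tenant_id


def run_with_tenant(tenant_id: str) -> None:
    token = current_tenant_id.set(tenant_id)
    try:
        ...  # anything here that calls current_tenant_id.get() sees tenant_id
    finally:
        current_tenant_id.reset(token)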

View File

@ -29,6 +29,7 @@ services:
- SMTP_PORT=${SMTP_PORT:-587} # For sending verification emails, if unspecified then defaults to '587'
- SMTP_USER=${SMTP_USER:-}
- SMTP_PASS=${SMTP_PASS:-}
- ENABLE_EMAIL_INVITES=${ENABLE_EMAIL_INVITES:-} # If enabled, will send users (using SMTP settings) an email to join the workspace
- EMAIL_FROM=${EMAIL_FROM:-}
- OAUTH_CLIENT_ID=${OAUTH_CLIENT_ID:-}
- OAUTH_CLIENT_SECRET=${OAUTH_CLIENT_SECRET:-}

View File

@ -0,0 +1,45 @@
"use client";
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
import { REGISTRATION_URL } from "@/lib/constants";
import { Button } from "@tremor/react";
import Link from "next/link";
import { FiLogIn } from "react-icons/fi";
const Page = () => {
return (
<AuthFlowContainer>
<div className="flex flex-col space-y-6">
<h2 className="text-2xl font-bold text-text-900 text-center">
Account Not Found
</h2>
<p className="text-text-700 max-w-md text-center">
We couldn&apos;t find your account in our records. To access Danswer,
you need to either:
</p>
<ul className="list-disc text-left text-text-600 w-full pl-6 mx-auto">
<li>Be invited to an existing Danswer organization</li>
<li>Create a new Danswer organization</li>
</ul>
<div className="flex justify-center">
<Link
href={`${REGISTRATION_URL}/register`}
className="w-full max-w-xs"
>
<Button size="lg" icon={FiLogIn} color="indigo" className="w-full">
Create New Organization
</Button>
</Link>
</div>
<p className="text-sm text-text-500 text-center">
Have an account with a different email?{" "}
<Link href="/auth/login" className="text-indigo-600 hover:underline">
Sign in
</Link>
</p>
</div>
</AuthFlowContainer>
);
};
export default Page;

View File

@ -1,21 +1,49 @@
"use client";
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
import { Button } from "@tremor/react";
import Link from "next/link";
import { FiLogIn } from "react-icons/fi";
const Page = () => {
return (
<div className="flex flex-col items-center justify-center h-screen">
<div className="font-bold">
Unable to login, please try again and/or contact an administrator.
<AuthFlowContainer>
<div className="flex flex-col space-y-6 max-w-md mx-auto">
<h2 className="text-2xl font-bold text-text-900 text-center">
Authentication Error
</h2>
<p className="text-text-700 text-center">
We encountered an issue while attempting to log you in.
</p>
<div className="bg-red-50 border border-red-200 rounded-lg p-4 shadow-sm">
<h3 className="text-red-800 font-semibold mb-2">Possible Issues:</h3>
<ul className="space-y-2">
<li className="flex items-center text-red-700">
<div className="w-2 h-2 bg-red-500 rounded-full mr-2"></div>
Incorrect or expired login credentials
</li>
<li className="flex items-center text-red-700">
<div className="w-2 h-2 bg-red-500 rounded-full mr-2"></div>
Temporary authentication system disruption
</li>
<li className="flex items-center text-red-700">
<div className="w-2 h-2 bg-red-500 rounded-full mr-2"></div>
Account access restrictions or permissions
</li>
</ul>
</div>
<Link href="/auth/login" className="w-full">
<Button size="lg" icon={FiLogIn} color="indigo" className="w-full">
Return to Login Page
</Button>
</Link>
<p className="text-sm text-text-500 text-center">
We recommend trying again. If you continue to experience problems,
please reach out to your system administrator for assistance.
</p>
</div>
<Link href="/auth/login" className="w-fit">
<Button className="mt-4" size="xs" icon={FiLogIn}>
Back to login
</Button>
</Link>
</div>
</AuthFlowContainer>
);
};

View File

@ -6,11 +6,15 @@ import { SettingsContext } from "@/components/settings/SettingsProvider";
export const LoginText = () => {
const settings = useContext(SettingsContext);
if (!settings) {
throw new Error("SettingsContext is not available");
}
// if (!settings) {
// throw new Error("SettingsContext is not available");
// }
return (
<>Log In to {settings?.enterpriseSettings?.application_name || "Danswer"}</>
<>
Log In to{" "}
{(settings && settings?.enterpriseSettings?.application_name) ||
"Danswer"}
</>
);
};

View File

@ -14,6 +14,7 @@ import Link from "next/link";
import { Logo } from "@/components/Logo";
import { LoginText } from "./LoginText";
import { getSecondsUntilExpiration } from "@/lib/time";
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
const Page = async ({
searchParams,
@ -51,7 +52,6 @@ const Page = async ({
if (authTypeMetadata?.requiresVerification && !currentUser.is_verified) {
return redirect("/auth/waiting-on-verification");
}
return redirect("/");
}
@ -70,46 +70,44 @@ const Page = async ({
}
return (
<main>
<AuthFlowContainer>
<div className="absolute top-10x w-full">
<HealthCheckBanner />
</div>
<div className="min-h-screen flex items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
<div>
<Logo height={64} width={64} className="mx-auto w-fit" />
{authUrl && authTypeMetadata && (
<>
<h2 className="text-center text-xl text-strong font-bold mt-6">
<LoginText />
</h2>
<SignInButton
authorizeUrl={authUrl}
authType={authTypeMetadata?.authType}
/>
</>
)}
{authTypeMetadata?.authType === "basic" && (
<Card className="mt-4 w-96">
<div className="flex">
<Title className="mb-2 mx-auto font-bold">
<LoginText />
</Title>
</div>
<EmailPasswordForm />
<div className="flex">
<Text className="mt-4 mx-auto">
Don&apos;t have an account?{" "}
<Link href="/auth/signup" className="text-link font-medium">
Create an account
</Link>
</Text>
</div>
</Card>
)}
</div>
<div>
{authUrl && authTypeMetadata && (
<>
<h2 className="text-center text-xl text-strong font-bold">
<LoginText />
</h2>
<SignInButton
authorizeUrl={authUrl}
authType={authTypeMetadata?.authType}
/>
</>
)}
{authTypeMetadata?.authType === "basic" && (
<Card className="mt-4 w-96">
<div className="flex">
<Title className="mb-2 mx-auto font-bold">
<LoginText />
</Title>
</div>
<EmailPasswordForm />
<div className="flex">
<Text className="mt-4 mx-auto">
Don&apos;t have an account?{" "}
<Link href="/auth/signup" className="text-link font-medium">
Create an account
</Link>
</Text>
</div>
</Card>
)}
</div>
</main>
</AuthFlowContainer>
);
};

View File

@ -11,6 +11,12 @@ export const GET = async (request: NextRequest) => {
const response = await fetch(url.toString());
const setCookieHeader = response.headers.get("set-cookie");
if (response.status === 401) {
return NextResponse.redirect(
new URL("/auth/create-account", getDomain(request))
);
}
if (!setCookieHeader) {
return NextResponse.redirect(new URL("/auth/error", getDomain(request)));
}

View File

@ -10,6 +10,7 @@ import { EmailPasswordForm } from "../login/EmailPasswordForm";
import { Card, Title, Text } from "@tremor/react";
import Link from "next/link";
import { Logo } from "@/components/Logo";
import { CLOUD_ENABLED } from "@/lib/constants";
const Page = async () => {
// catch cases where the backend is completely unreachable here
@ -25,6 +26,9 @@ const Page = async () => {
} catch (e) {
console.log(`Some fetch failed for the login page - ${e}`);
}
if (CLOUD_ENABLED) {
return redirect("/auth/login");
}
// simply take the user to the home page if Auth is disabled
if (authTypeMetadata?.authType === "disabled") {

View File

@ -19,6 +19,8 @@ import { HeaderTitle } from "@/components/header/HeaderTitle";
import { Logo } from "@/components/Logo";
import { UserProvider } from "@/components/user/UserProvider";
import { ProviderContextProvider } from "@/components/chat_search/ProviderContext";
import { redirect } from "next/navigation";
import { headers } from "next/headers";
const inter = Inter({
subsets: ["latin"],
@ -56,8 +58,6 @@ export default async function RootLayout({
const combinedSettings = await fetchSettingsSS();
if (!combinedSettings) {
// Just display a simple full page error if fetching fails.
return (
<html lang="en" className={`${inter.variable} font-sans`}>
<Head>

View File

@ -0,0 +1,16 @@
import { Logo } from "../Logo";
export default function AuthFlowContainer({
children,
}: {
children: React.ReactNode;
}) {
return (
<div className="flex flex-col items-center justify-center min-h-screen bg-background">
<div className="w-full max-w-md p-8 gap-y-4 bg-white flex items-center flex-col rounded-xl shadow-lg border border-bacgkround-100">
<Logo width={70} height={70} />
{children}
</div>
</div>
);
}

View File

@ -40,7 +40,7 @@ export async function fetchSettingsSS(): Promise<CombinedSettings | null> {
let settings: Settings;
if (!results[0].ok) {
if (results[0].status === 403) {
if (results[0].status === 403 || results[0].status === 401) {
settings = {
gpu_enabled: false,
chat_page_enabled: true,
@ -62,7 +62,7 @@ export async function fetchSettingsSS(): Promise<CombinedSettings | null> {
let enterpriseSettings: EnterpriseSettings | null = null;
if (tasks.length > 1) {
if (!results[1].ok) {
if (results[1].status !== 403) {
if (results[1].status !== 403 && results[1].status !== 401) {
throw new Error(
`fetchEnterpriseSettingsSS failed: status=${results[1].status} body=${await results[1].text()}`
);

View File

@ -55,3 +55,7 @@ export const CUSTOM_ANALYTICS_ENABLED = process.env.CUSTOM_ANALYTICS_SECRET_KEY
export const DISABLE_LLM_DOC_RELEVANCE =
process.env.DISABLE_LLM_DOC_RELEVANCE?.toLowerCase() === "true";
export const CLOUD_ENABLED = process.env.NEXT_PUBLIC_CLOUD_ENABLED;
export const REGISTRATION_URL =
process.env.INTERNAL_URL || "http://127.0.0.1:3001";