Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-05-05 01:10:27 +02:00)
Add tenant context (#2596)
* add proper tenant context to background tasks
* update for new session logic
* remove unnecessary functions
* add additional tenant context
* update ports
* proper format / directory structure
* update ports
* ensure tenant context properly passed to ee bg tasks
* add user provisioning
* nit
* validated for multi tenant
* auth
* nit
* nit
* nit
* nit
* validate pruning
* evaluate integration tests
* at long last, validated celery beat
* nit: minor edge case patched
* minor
* validate update
* nit
This commit is contained in:
parent 9be54a2b4c
commit f40c5ca9bd
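Much of this change threads a tenant id through Celery tasks and database sessions via the `current_tenant_id` context variable imported from `shared_configs.configs`. A minimal sketch of the set/reset pattern the handlers below rely on; the variable name comes from the diff, everything else here is illustrative:

from contextvars import ContextVar

# assumed shape of shared_configs.configs.current_tenant_id
current_tenant_id: ContextVar[str] = ContextVar("current_tenant_id", default="public")

def run_in_tenant_context(tenant_id: str) -> None:
    token = current_tenant_id.set(tenant_id)  # bind the tenant for this unit of work
    try:
        ...  # tenant-scoped work: DB sessions, Vespa calls, etc.
    finally:
        current_tenant_id.reset(token)  # restore whatever was bound before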
@@ -1,6 +1,6 @@
 # A generic, single database configuration.
 
-[alembic]
+[DEFAULT]
 # path to migration scripts
 script_location = alembic
 
@@ -47,7 +47,8 @@ prepend_sys_path = .
 # version_path_separator = :
 # version_path_separator = ;
 # version_path_separator = space
-version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
+version_path_separator = os
+# Use os.pathsep. Default configuration used for new projects.
 
 # set to 'true' to search source files recursively
 # in each "version_locations" directory
@@ -106,3 +107,12 @@ formatter = generic
 [formatter_generic]
 format = %(levelname)-5.5s [%(name)s] %(message)s
 datefmt = %H:%M:%S
+
+
+[alembic]
+script_location = alembic
+version_locations = %(script_location)s/versions
+
+[schema_private]
+script_location = alembic_tenants
+version_locations = %(script_location)s/versions
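The ini now carries two named sections: [alembic] for the per-tenant schema migrations and [schema_private] for the shared public tables kept under alembic_tenants/. Alembic selects the section with its --name/-n option, or the ini_section argument of Config when driven from Python. A sketch of both invocations (illustrative, not taken from this commit):

from alembic import command
from alembic.config import Config

# per-tenant migrations, [alembic] section (the default)
command.upgrade(Config("alembic.ini"), "head")

# shared public-table migrations, [schema_private] section
# CLI equivalent: alembic -n schema_private upgrade head
command.upgrade(Config("alembic.ini", ini_section="schema_private"), "head")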
@@ -1,21 +1,22 @@
+from typing import Any
 import asyncio
 from logging.config import fileConfig
 
 from alembic import context
-from danswer.db.engine import build_connection_string
-from danswer.db.models import Base
 from sqlalchemy import pool
 from sqlalchemy.engine import Connection
 from sqlalchemy.ext.asyncio import create_async_engine
-from celery.backends.database.session import ResultModelBase  # type: ignore
-from sqlalchemy.schema import SchemaItem
 from sqlalchemy.sql import text
+
+from danswer.configs.app_configs import MULTI_TENANT
+from danswer.db.engine import build_connection_string
+from danswer.db.models import Base
+from celery.backends.database.session import ResultModelBase  # type: ignore
 
 # Alembic Config object
 config = context.config
 
 # Interpret the config file for Python logging.
-# This line sets up loggers basically.
 if config.config_file_name is not None and config.attributes.get(
     "configure_logger", True
 ):
@@ -35,8 +36,7 @@ def get_schema_options() -> tuple[str, bool]:
     for pair in arg.split(","):
         if "=" in pair:
             key, value = pair.split("=", 1)
-            x_args[key] = value
+            x_args[key.strip()] = value.strip()
 
     schema_name = x_args.get("schema", "public")
     create_schema = x_args.get("create_schema", "true").lower() == "true"
     return schema_name, create_schema
@@ -46,11 +46,7 @@ EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
 
 
 def include_object(
-    object: SchemaItem,
-    name: str,
-    type_: str,
-    reflected: bool,
-    compare_to: SchemaItem | None,
+    object: Any, name: str, type_: str, reflected: bool, compare_to: Any
 ) -> bool:
     if type_ == "table" and name in EXCLUDE_TABLES:
         return False
@@ -59,7 +55,6 @@ def include_object(
-
 
 def run_migrations_offline() -> None:
     """Run migrations in 'offline' mode.
 
     This configures the context with just a URL
     and not an Engine, though an Engine is acceptable
     here as well. By skipping the Engine creation
@@ -67,17 +62,18 @@ def run_migrations_offline() -> None:
     Calls to context.execute() here emit the given string to the
     script output.
     """
+    schema_name, _ = get_schema_options()
     url = build_connection_string()
-    schema, _ = get_schema_options()
 
     context.configure(
         url=url,
         target_metadata=target_metadata,  # type: ignore
         literal_binds=True,
         include_object=include_object,
-        dialect_opts={"paramstyle": "named"},
-        version_table_schema=schema,
+        version_table_schema=schema_name,
         include_schemas=True,
+        script_location=config.get_main_option("script_location"),
+        dialect_opts={"paramstyle": "named"},
     )
 
     with context.begin_transaction():
@@ -85,20 +81,30 @@ def run_migrations_offline() -> None:
 
 
 def do_run_migrations(connection: Connection) -> None:
-    schema, create_schema = get_schema_options()
+    schema_name, create_schema = get_schema_options()
 
+    if MULTI_TENANT and schema_name == "public":
+        raise ValueError(
+            "Cannot run default migrations in public schema when multi-tenancy is enabled. "
+            "Please specify a tenant-specific schema."
+        )
+
     if create_schema:
-        connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))
+        connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
         connection.execute(text("COMMIT"))
 
-    connection.execute(text(f'SET search_path TO "{schema}"'))
+    # Set search_path to the target schema
+    connection.execute(text(f'SET search_path TO "{schema_name}"'))
 
     context.configure(
         connection=connection,
         target_metadata=target_metadata,  # type: ignore
-        version_table_schema=schema,
+        include_object=include_object,
+        version_table_schema=schema_name,
         include_schemas=True,
         compare_type=True,
         compare_server_default=True,
+        script_location=config.get_main_option("script_location"),
     )
 
     with context.begin_transaction():
@@ -106,7 +112,6 @@ def do_run_migrations(connection: Connection) -> None:
 
 
 async def run_async_migrations() -> None:
-    """Run migrations in 'online' mode."""
     connectable = create_async_engine(
         build_connection_string(),
         poolclass=pool.NullPool,
@@ -119,7 +124,6 @@ async def run_async_migrations() -> None:
 
 
 def run_migrations_online() -> None:
-    """Run migrations in 'online' mode."""
     asyncio.run(run_async_migrations())
 
 
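get_schema_options() pulls the target schema out of Alembic's -x arguments (via context.get_x_argument), so the same env.py can be pointed at any tenant schema. A hedged sketch of driving it from Python; the tenant schema name is made up, and passing -x values through Config.cmd_opts mirrors what the Alembic CLI does:

from types import SimpleNamespace

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")
# CLI equivalent: alembic -x schema=tenant_abc123 -x create_schema=true upgrade head
cfg.cmd_opts = SimpleNamespace(x=["schema=tenant_abc123", "create_schema=true"])
command.upgrade(cfg, "head")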
@@ -20,7 +20,7 @@ depends_on: None = None
 def upgrade() -> None:
     conn = op.get_bind()
     existing_ids_and_chosen_assistants = conn.execute(
-        sa.text("select id, chosen_assistants from public.user")
+        sa.text('select id, chosen_assistants from "user"')
     )
     op.drop_column(
         "user",
@@ -37,7 +37,7 @@ def upgrade() -> None:
     for id, chosen_assistants in existing_ids_and_chosen_assistants:
         conn.execute(
             sa.text(
-                "update public.user set chosen_assistants = :chosen_assistants where id = :id"
+                'update "user" set chosen_assistants = :chosen_assistants where id = :id'
             ),
             {"chosen_assistants": json.dumps(chosen_assistants), "id": id},
         )
@@ -46,7 +46,7 @@ def upgrade() -> None:
 def downgrade() -> None:
     conn = op.get_bind()
     existing_ids_and_chosen_assistants = conn.execute(
-        sa.text("select id, chosen_assistants from public.user")
+        sa.text('select id, chosen_assistants from "user"')
     )
     op.drop_column(
         "user",
@@ -59,7 +59,7 @@ def downgrade() -> None:
     for id, chosen_assistants in existing_ids_and_chosen_assistants:
         conn.execute(
             sa.text(
-                "update public.user set chosen_assistants = :chosen_assistants where id = :id"
+                'update "user" set chosen_assistants = :chosen_assistants where id = :id'
            ),
             {"chosen_assistants": chosen_assistants, "id": id},
         )
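The switch from public.user to an unqualified "user" matters because do_run_migrations() now sets search_path to the tenant's schema before migrations run, so the same statement targets each tenant's own user table (the quotes are also needed since user is a reserved word). A small illustration of the effect; the connection string and schema name are placeholders:

from sqlalchemy import create_engine, text

engine = create_engine("postgresql+psycopg2://user:pass@localhost/danswer")  # placeholder DSN
with engine.connect() as conn:
    conn.execute(text('SET search_path TO "tenant_abc123"'))  # hypothetical tenant schema
    # unqualified "user" now resolves to tenant_abc123.user instead of public.user
    rows = conn.execute(text('select id, chosen_assistants from "user"')).fetchall()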
backend/alembic_tenants/README.md (new file, 3 lines)
@@ -0,0 +1,3 @@
+These files are for public table migrations when operating with multi tenancy.
+
+If you are not a Danswer developer, you can ignore this directory entirely.
backend/alembic_tenants/env.py (new file, 111 lines)
@@ -0,0 +1,111 @@
+import asyncio
+from logging.config import fileConfig
+
+from sqlalchemy import pool
+from sqlalchemy.engine import Connection
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.schema import SchemaItem
+
+from alembic import context
+from danswer.db.engine import build_connection_string
+from danswer.db.models import PublicBase
+
+# this is the Alembic Config object, which provides
+# access to the values within the .ini file in use.
+config = context.config
+
+# Interpret the config file for Python logging.
+# This line sets up loggers basically.
+if config.config_file_name is not None and config.attributes.get(
+    "configure_logger", True
+):
+    fileConfig(config.config_file_name)
+
+# add your model's MetaData object here
+# for 'autogenerate' support
+# from myapp import mymodel
+# target_metadata = mymodel.Base.metadata
+target_metadata = [PublicBase.metadata]
+
+# other values from the config, defined by the needs of env.py,
+# can be acquired:
+# my_important_option = config.get_main_option("my_important_option")
+# ... etc.
+
+EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
+
+
+def include_object(
+    object: SchemaItem,
+    name: str,
+    type_: str,
+    reflected: bool,
+    compare_to: SchemaItem | None,
+) -> bool:
+    if type_ == "table" and name in EXCLUDE_TABLES:
+        return False
+    return True
+
+
+def run_migrations_offline() -> None:
+    """Run migrations in 'offline' mode.
+
+    This configures the context with just a URL
+    and not an Engine, though an Engine is acceptable
+    here as well. By skipping the Engine creation
+    we don't even need a DBAPI to be available.
+
+    Calls to context.execute() here emit the given string to the
+    script output.
+
+    """
+    url = build_connection_string()
+    context.configure(
+        url=url,
+        target_metadata=target_metadata,  # type: ignore
+        literal_binds=True,
+        dialect_opts={"paramstyle": "named"},
+    )
+
+    with context.begin_transaction():
+        context.run_migrations()
+
+
+def do_run_migrations(connection: Connection) -> None:
+    context.configure(
+        connection=connection,
+        target_metadata=target_metadata,  # type: ignore
+        include_object=include_object,
+    )  # type: ignore
+
+    with context.begin_transaction():
+        context.run_migrations()
+
+
+async def run_async_migrations() -> None:
+    """In this scenario we need to create an Engine
+    and associate a connection with the context.
+
+    """
+
+    connectable = create_async_engine(
+        build_connection_string(),
+        poolclass=pool.NullPool,
+    )
+
+    async with connectable.connect() as connection:
+        await connection.run_sync(do_run_migrations)
+
+    await connectable.dispose()
+
+
+def run_migrations_online() -> None:
+    """Run migrations in 'online' mode."""
+
+    asyncio.run(run_async_migrations())
+
+
+if context.is_offline_mode():
+    run_migrations_offline()
+else:
+    run_migrations_online()
backend/alembic_tenants/script.py.mako (new file, 24 lines)
@@ -0,0 +1,24 @@
+"""${message}
+
+Revision ID: ${up_revision}
+Revises: ${down_revision | comma,n}
+Create Date: ${create_date}
+
+"""
+from alembic import op
+import sqlalchemy as sa
+${imports if imports else ""}
+
+# revision identifiers, used by Alembic.
+revision = ${repr(up_revision)}
+down_revision = ${repr(down_revision)}
+branch_labels = ${repr(branch_labels)}
+depends_on = ${repr(depends_on)}
+
+
+def upgrade() -> None:
+    ${upgrades if upgrades else "pass"}
+
+
+def downgrade() -> None:
+    ${downgrades if downgrades else "pass"}
@@ -0,0 +1,24 @@
+import sqlalchemy as sa
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "14a83a331951"
+down_revision = None
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.create_table(
+        "user_tenant_mapping",
+        sa.Column("email", sa.String(), nullable=False),
+        sa.Column("tenant_id", sa.String(), nullable=False),
+        sa.UniqueConstraint("email", "tenant_id", name="uq_user_tenant"),
+        sa.UniqueConstraint("email", name="uq_email"),
+        schema="public",
+    )
+
+
+def downgrade() -> None:
+    op.drop_table("user_tenant_mapping", schema="public")
@@ -34,6 +34,7 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
 class UserCreate(schemas.BaseUserCreate):
     role: UserRole = UserRole.BASIC
     has_web_login: bool | None = True
+    tenant_id: str | None = None
 
 
 class UserUpdate(schemas.BaseUserUpdate):
@@ -26,11 +26,14 @@ from fastapi_users import schemas
 from fastapi_users import UUIDIDMixin
 from fastapi_users.authentication import AuthenticationBackend
 from fastapi_users.authentication import CookieTransport
+from fastapi_users.authentication import JWTStrategy
 from fastapi_users.authentication import Strategy
 from fastapi_users.authentication.strategy.db import AccessTokenDatabase
 from fastapi_users.authentication.strategy.db import DatabaseStrategy
 from fastapi_users.openapi import OpenAPIResponseType
 from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
+from sqlalchemy import select
+from sqlalchemy.orm import attributes
 from sqlalchemy.orm import Session
 
 from danswer.auth.invited_users import get_invited_users
@@ -42,7 +45,9 @@ from danswer.configs.app_configs import DATA_PLANE_SECRET
 from danswer.configs.app_configs import DISABLE_AUTH
 from danswer.configs.app_configs import EMAIL_FROM
 from danswer.configs.app_configs import EXPECTED_API_KEY
+from danswer.configs.app_configs import MULTI_TENANT
 from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
+from danswer.configs.app_configs import SECRET_JWT_KEY
 from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
 from danswer.configs.app_configs import SMTP_PASS
 from danswer.configs.app_configs import SMTP_PORT
@@ -60,15 +65,21 @@ from danswer.db.auth import get_access_token_db
 from danswer.db.auth import get_default_admin_user_emails
 from danswer.db.auth import get_user_count
 from danswer.db.auth import get_user_db
+from danswer.db.auth import SQLAlchemyUserAdminDB
+from danswer.db.engine import get_async_session_with_tenant
 from danswer.db.engine import get_session
+from danswer.db.engine import get_session_with_tenant
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.models import AccessToken
+from danswer.db.models import OAuthAccount
 from danswer.db.models import User
+from danswer.db.models import UserTenantMapping
 from danswer.db.users import get_user_by_email
 from danswer.utils.logger import setup_logger
 from danswer.utils.telemetry import optional_telemetry
 from danswer.utils.telemetry import RecordType
 from danswer.utils.variable_functionality import fetch_versioned_implementation
+from shared_configs.configs import current_tenant_id
 
 logger = setup_logger()
 
@@ -136,8 +147,8 @@ def verify_email_is_invited(email: str) -> None:
         raise PermissionError("User not on allowed user whitelist")
 
 
-def verify_email_in_whitelist(email: str) -> None:
-    with Session(get_sqlalchemy_engine()) as db_session:
+def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
+    with get_session_with_tenant(tenant_id) as db_session:
         if not get_user_by_email(email, db_session):
             verify_email_is_invited(email)
 
@@ -157,6 +168,20 @@ def verify_email_domain(email: str) -> None:
         )
 
 
+def get_tenant_id_for_email(email: str) -> str:
+    if not MULTI_TENANT:
+        return "public"
+    # Implement logic to get tenant_id from the mapping table
+    with Session(get_sqlalchemy_engine()) as db_session:
+        result = db_session.execute(
+            select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
+        )
+        tenant_id = result.scalar_one_or_none()
+    if tenant_id is None:
+        raise exceptions.UserNotExists()
+    return tenant_id
+
+
 def send_user_verification_email(
     user_email: str,
     token: str,
@@ -221,6 +246,29 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
                 raise exceptions.UserAlreadyExists()
             return user
 
+    async def on_after_login(
+        self,
+        user: User,
+        request: Request | None = None,
+        response: Response | None = None,
+    ) -> None:
+        if response is None or not MULTI_TENANT:
+            return
+
+        tenant_id = get_tenant_id_for_email(user.email)
+
+        tenant_token = jwt.encode(
+            {"tenant_id": tenant_id}, SECRET_JWT_KEY, algorithm="HS256"
+        )
+
+        response.set_cookie(
+            key="tenant_details",
+            value=tenant_token,
+            httponly=True,
+            secure=WEB_DOMAIN.startswith("https"),
+            samesite="lax",
+        )
+
     async def oauth_callback(
         self: "BaseUserManager[models.UOAP, models.ID]",
         oauth_name: str,
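on_after_login stores the tenant id in a tenant_details cookie as a signed JWT. A sketch of how a later request could recover it, assuming the PyJWT-style API and the HS256/SECRET_JWT_KEY parameters used above; the helper name is made up:

import jwt  # assumes a PyJWT-compatible module, as used by jwt.encode() above

def tenant_id_from_cookie(tenant_token: str, secret: str) -> str | None:
    # hypothetical helper: the reverse of the jwt.encode() call in on_after_login
    try:
        payload = jwt.decode(tenant_token, secret, algorithms=["HS256"])
    except jwt.PyJWTError:
        return None
    return payload.get("tenant_id")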
@@ -234,43 +282,109 @@
         associate_by_email: bool = False,
         is_verified_by_default: bool = False,
     ) -> models.UOAP:
-        verify_email_in_whitelist(account_email)
-        verify_email_domain(account_email)
-        user = await super().oauth_callback(  # type: ignore
-            oauth_name=oauth_name,
-            access_token=access_token,
-            account_id=account_id,
-            account_email=account_email,
-            expires_at=expires_at,
-            refresh_token=refresh_token,
-            request=request,
-            associate_by_email=associate_by_email,
-            is_verified_by_default=is_verified_by_default,
-        )
+        # Get tenant_id from mapping table
+        try:
+            tenant_id = (
+                get_tenant_id_for_email(account_email) if MULTI_TENANT else "public"
+            )
+        except exceptions.UserNotExists:
+            raise HTTPException(status_code=401, detail="User not found")
+
+        if not tenant_id:
+            raise HTTPException(status_code=401, detail="User not found")
+
+        token = None
+        async with get_async_session_with_tenant(tenant_id) as db_session:
+            token = current_tenant_id.set(tenant_id)
+            # Print a list of tables in the current database session
+            verify_email_in_whitelist(account_email, tenant_id)
+            verify_email_domain(account_email)
+            if MULTI_TENANT:
+                tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
+                self.user_db = tenant_user_db
+                self.database = tenant_user_db
+
+            oauth_account_dict = {
+                "oauth_name": oauth_name,
+                "access_token": access_token,
+                "account_id": account_id,
+                "account_email": account_email,
+                "expires_at": expires_at,
+                "refresh_token": refresh_token,
+            }
+
+            try:
+                # Attempt to get user by OAuth account
+                user = await self.get_by_oauth_account(oauth_name, account_id)
+
+            except exceptions.UserNotExists:
+                try:
+                    # Attempt to get user by email
+                    user = await self.get_by_email(account_email)
+                    if not associate_by_email:
+                        raise exceptions.UserAlreadyExists()
+
+                    user = await self.user_db.add_oauth_account(
+                        user, oauth_account_dict
+                    )
+
+                # If user not found by OAuth account or email, create a new user
+                except exceptions.UserNotExists:
+                    password = self.password_helper.generate()
+                    user_dict = {
+                        "email": account_email,
+                        "hashed_password": self.password_helper.hash(password),
+                        "is_verified": is_verified_by_default,
+                    }
+
+                    user = await self.user_db.create(user_dict)
+                    user = await self.user_db.add_oauth_account(
+                        user, oauth_account_dict
+                    )
+                    await self.on_after_register(user, request)
+
+            else:
+                for existing_oauth_account in user.oauth_accounts:
+                    if (
+                        existing_oauth_account.account_id == account_id
+                        and existing_oauth_account.oauth_name == oauth_name
+                    ):
+                        user = await self.user_db.update_oauth_account(
+                            user, existing_oauth_account, oauth_account_dict
+                        )
 
         # NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
         # re-authenticate that frequently, so by default this is disabled
+
         if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
             oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
-            await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry})
-
-        # this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
-        # otherwise, the oidc expiry will always be old, and the user will never be able to login
-        if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY:
-            await self.user_db.update(user, update_dict={"oidc_expiry": None})
+            await self.user_db.update(
+                user, update_dict={"oidc_expiry": oidc_expiry}
+            )
 
         # Handle case where user has used product outside of web and is now creating an account through web
-        if not user.has_web_login:
+        if not user.has_web_login:  # type: ignore
             await self.user_db.update(
                 user,
-                update_dict={
+                {
                     "is_verified": is_verified_by_default,
                     "has_web_login": True,
                 },
             )
             user.is_verified = is_verified_by_default
-            user.has_web_login = True
+            user.has_web_login = True  # type: ignore
+
+        # this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
+        # otherwise, the oidc expiry will always be old, and the user will never be able to login
+        if (
+            user.oidc_expiry is not None  # type: ignore
+            and not TRACK_EXTERNAL_IDP_EXPIRY
+        ):
+            await self.user_db.update(user, {"oidc_expiry": None})
+            user.oidc_expiry = None  # type: ignore
+
+        if token:
+            current_tenant_id.reset(token)
 
         return user
 
@@ -303,13 +417,34 @@
     async def authenticate(
         self, credentials: OAuth2PasswordRequestForm
     ) -> Optional[User]:
+        email = credentials.username
+
+        # Get tenant_id from mapping table
+
+        tenant_id = get_tenant_id_for_email(email)
+        if not tenant_id:
+            # User not found in mapping
+            self.password_helper.hash(credentials.password)
+            return None
+
+        # Create a tenant-specific session
+        async with get_async_session_with_tenant(tenant_id) as tenant_session:
+            tenant_user_db: SQLAlchemyUserDatabase = SQLAlchemyUserDatabase(
+                tenant_session, User
+            )
+            self.user_db = tenant_user_db
+
+        # Proceed with authentication
         try:
-            user = await self.get_by_email(credentials.username)
+            user = await self.get_by_email(email)
+
         except exceptions.UserNotExists:
             self.password_helper.hash(credentials.password)
             return None
 
-        if not user.has_web_login:
+        has_web_login = attributes.get_attribute(user, "has_web_login")
+
+        if not has_web_login:
             raise HTTPException(
                 status_code=status.HTTP_403_FORBIDDEN,
                 detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
@@ -322,7 +457,9 @@
             return None
 
         if updated_password_hash is not None:
-            await self.user_db.update(user, {"hashed_password": updated_password_hash})
+            await self.user_db.update(
+                user, {"hashed_password": updated_password_hash}
+            )
 
         return user
 
@@ -339,20 +476,26 @@ cookie_transport = CookieTransport(
 )
 
 
+def get_jwt_strategy() -> JWTStrategy:
+    return JWTStrategy(
+        secret=USER_AUTH_SECRET,
+        lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
+    )
+
+
 def get_database_strategy(
     access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
 ) -> DatabaseStrategy:
-    strategy = DatabaseStrategy(
+    return DatabaseStrategy(
         access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS  # type: ignore
     )
-    return strategy
 
 
 auth_backend = AuthenticationBackend(
-    name="database",
+    name="jwt" if MULTI_TENANT else "database",
     transport=cookie_transport,
-    get_strategy=get_database_strategy,
-)
+    get_strategy=get_jwt_strategy if MULTI_TENANT else get_database_strategy,  # type: ignore
+)  # type: ignore
 
 
 class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
@@ -366,9 +509,11 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
         This way the login router does not need to be included
         """
         router = APIRouter()
+
         get_current_user_token = self.authenticator.current_user_token(
             active=True, verified=requires_verification
         )
+
         logout_responses: OpenAPIResponseType = {
             **{
                 status.HTTP_401_UNAUTHORIZED: {
@@ -415,8 +560,8 @@ async def optional_user_(
 
 async def optional_user(
     request: Request,
-    user: User | None = Depends(optional_fastapi_current_user),
     db_session: Session = Depends(get_session),
+    user: User | None = Depends(optional_fastapi_current_user),
 ) -> User | None:
     versioned_fetch_user = fetch_versioned_implementation(
         "danswer.auth.users", "optional_user_"
@@ -23,6 +23,7 @@ from danswer.background.celery.celery_redis import RedisConnectorPruning
 from danswer.background.celery.celery_redis import RedisDocumentSet
 from danswer.background.celery.celery_redis import RedisUserGroup
 from danswer.background.celery.celery_utils import celery_is_worker_primary
+from danswer.background.update import get_all_tenant_ids
 from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
 from danswer.configs.constants import DanswerCeleryPriority
 from danswer.configs.constants import DanswerRedisLocks
@@ -70,7 +71,6 @@ def celery_task_postrun(
         return
 
     task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")
-    # logger.debug(f"Result: {retval}")
 
     if state not in READY_STATES:
         return
@@ -437,48 +437,58 @@ celery_app.autodiscover_tasks(
 #####
 # Celery Beat (Periodic Tasks) Settings
 #####
-celery_app.conf.beat_schedule = {
-    "check-for-vespa-sync": {
-        "task": "check_for_vespa_sync_task",
-        "schedule": timedelta(seconds=5),
-        "options": {"priority": DanswerCeleryPriority.HIGH},
-    },
-}
-celery_app.conf.beat_schedule.update(
-    {
-        "check-for-connector-deletion-task": {
-            "task": "check_for_connector_deletion_task",
-            # don't need to check too often, since we kick off a deletion initially
-            # during the API call that actually marks the CC pair for deletion
-            "schedule": timedelta(seconds=60),
-            "options": {"priority": DanswerCeleryPriority.HIGH},
-        },
-    }
-)
-celery_app.conf.beat_schedule.update(
-    {
-        "check-for-prune": {
-            "task": "check_for_prune_task_2",
-            "schedule": timedelta(seconds=60),
-            "options": {"priority": DanswerCeleryPriority.HIGH},
-        },
-    }
-)
-celery_app.conf.beat_schedule.update(
-    {
-        "kombu-message-cleanup": {
-            "task": "kombu_message_cleanup_task",
-            "schedule": timedelta(seconds=3600),
-            "options": {"priority": DanswerCeleryPriority.LOWEST},
-        },
-    }
-)
-celery_app.conf.beat_schedule.update(
-    {
-        "monitor-vespa-sync": {
-            "task": "monitor_vespa_sync",
-            "schedule": timedelta(seconds=5),
-            "options": {"priority": DanswerCeleryPriority.HIGH},
-        },
-    }
-)
+tenant_ids = get_all_tenant_ids()
+
+tasks_to_schedule = [
+    {
+        "name": "check-for-vespa-sync",
+        "task": "check_for_vespa_sync_task",
+        "schedule": timedelta(seconds=5),
+        "options": {"priority": DanswerCeleryPriority.HIGH},
+    },
+    {
+        "name": "check-for-connector-deletion",
+        "task": "check_for_connector_deletion_task",
+        "schedule": timedelta(seconds=60),
+        "options": {"priority": DanswerCeleryPriority.HIGH},
+    },
+    {
+        "name": "check-for-prune",
+        "task": "check_for_prune_task_2",
+        "schedule": timedelta(seconds=10),
+        "options": {"priority": DanswerCeleryPriority.HIGH},
+    },
+    {
+        "name": "kombu-message-cleanup",
+        "task": "kombu_message_cleanup_task",
+        "schedule": timedelta(seconds=3600),
+        "options": {"priority": DanswerCeleryPriority.LOWEST},
+    },
+    {
+        "name": "monitor-vespa-sync",
+        "task": "monitor_vespa_sync",
+        "schedule": timedelta(seconds=5),
+        "options": {"priority": DanswerCeleryPriority.HIGH},
+    },
+]
+
+# Build the celery beat schedule dynamically
+beat_schedule = {}
+
+for tenant_id in tenant_ids:
+    for task in tasks_to_schedule:
+        task_name = f"{task['name']}-{tenant_id}"  # Unique name for each scheduled task
+        beat_schedule[task_name] = {
+            "task": task["task"],
+            "schedule": task["schedule"],
+            "options": task["options"],
+            "args": (tenant_id,),  # Must pass tenant_id as an argument
+        }
+
+# Include any existing beat schedules
+existing_beat_schedule = celery_app.conf.beat_schedule or {}
+beat_schedule.update(existing_beat_schedule)
+
+# Update the Celery app configuration once
+celery_app.conf.beat_schedule = beat_schedule
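With this loop, each periodic task is registered once per tenant and receives the tenant id as its only positional argument. For two hypothetical tenants, the resulting schedule would look roughly like the sketch below (illustrative values; DanswerCeleryPriority is the same constant used above):

from datetime import timedelta

# hypothetical result for tenant_ids = ["public", "tenant_abc123"]
beat_schedule = {
    "check-for-vespa-sync-public": {
        "task": "check_for_vespa_sync_task",
        "schedule": timedelta(seconds=5),
        "options": {"priority": DanswerCeleryPriority.HIGH},
        "args": ("public",),
    },
    "check-for-vespa-sync-tenant_abc123": {
        "task": "check_for_vespa_sync_task",
        "schedule": timedelta(seconds=5),
        "options": {"priority": DanswerCeleryPriority.HIGH},
        "args": ("tenant_abc123",),
    },
    # ...one entry per (task, tenant) pair for the other four tasks
}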
@@ -107,6 +107,7 @@ class RedisObjectHelper(ABC):
         db_session: Session,
         redis_client: Redis,
         lock: redis.lock.Lock,
+        tenant_id: str | None,
     ) -> int | None:
         pass
 
@@ -122,6 +123,7 @@ class RedisDocumentSet(RedisObjectHelper):
         db_session: Session,
         redis_client: Redis,
         lock: redis.lock.Lock,
+        tenant_id: str | None,
     ) -> int | None:
         last_lock_time = time.monotonic()
 
@@ -146,7 +148,7 @@ class RedisDocumentSet(RedisObjectHelper):
 
             result = celery_app.send_task(
                 "vespa_metadata_sync_task",
-                kwargs=dict(document_id=doc.id),
+                kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
                 queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
                 task_id=custom_task_id,
                 priority=DanswerCeleryPriority.LOW,
@@ -168,6 +170,7 @@ class RedisUserGroup(RedisObjectHelper):
         db_session: Session,
         redis_client: Redis,
         lock: redis.lock.Lock,
+        tenant_id: str | None,
     ) -> int | None:
         last_lock_time = time.monotonic()
 
@@ -204,7 +207,7 @@ class RedisUserGroup(RedisObjectHelper):
 
             result = celery_app.send_task(
                 "vespa_metadata_sync_task",
-                kwargs=dict(document_id=doc.id),
+                kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
                 queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
                 task_id=custom_task_id,
                 priority=DanswerCeleryPriority.LOW,
@@ -244,6 +247,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
         db_session: Session,
         redis_client: Redis,
         lock: redis.lock.Lock,
+        tenant_id: str | None,
     ) -> int | None:
         last_lock_time = time.monotonic()
 
@@ -278,7 +282,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
             # Priority on sync's triggered by new indexing should be medium
             result = celery_app.send_task(
                 "vespa_metadata_sync_task",
-                kwargs=dict(document_id=doc.id),
+                kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
                 queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
                 task_id=custom_task_id,
                 priority=DanswerCeleryPriority.MEDIUM,
@@ -300,6 +304,7 @@ class RedisConnectorDeletion(RedisObjectHelper):
         db_session: Session,
         redis_client: Redis,
         lock: redis.lock.Lock,
+        tenant_id: str | None,
     ) -> int | None:
         last_lock_time = time.monotonic()
 
@@ -336,6 +341,7 @@ class RedisConnectorDeletion(RedisObjectHelper):
                     document_id=doc.id,
                     connector_id=cc_pair.connector_id,
                     credential_id=cc_pair.credential_id,
+                    tenant_id=tenant_id,
                 ),
                 queue=DanswerCeleryQueues.CONNECTOR_DELETION,
                 task_id=custom_task_id,
@@ -409,6 +415,7 @@ class RedisConnectorPruning(RedisObjectHelper):
         db_session: Session,
         redis_client: Redis,
         lock: redis.lock.Lock | None,
+        tenant_id: str | None,
     ) -> int | None:
         last_lock_time = time.monotonic()
 
@@ -442,6 +449,7 @@ class RedisConnectorPruning(RedisObjectHelper):
                     document_id=doc_id,
                     connector_id=cc_pair.connector_id,
                     credential_id=cc_pair.credential_id,
+                    tenant_id=tenant_id,
                 ),
                 queue=DanswerCeleryQueues.CONNECTOR_DELETION,
                 task_id=custom_task_id,
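Each send_task call above now carries tenant_id in its kwargs, which implies the worker-side task signatures accept it and use it to open a tenant-scoped session. An illustrative, non-verbatim counterpart for the Vespa metadata sync task:

from celery import shared_task

from danswer.db.engine import get_session_with_tenant  # helper referenced elsewhere in this change

# Sketch only: the real task body lives in the Vespa tasks module and does more than this.
@shared_task(name="vespa_metadata_sync_task")
def vespa_metadata_sync_task(document_id: str, tenant_id: str | None) -> None:
    with get_session_with_tenant(tenant_id) as db_session:
        ...  # look up the document and push its metadata to Vespa within this tenant's schema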
@@ -23,7 +23,7 @@ from danswer.redis.redis_pool import get_redis_client
     soft_time_limit=JOB_TIMEOUT,
     trail=False,
 )
-def check_for_connector_deletion_task() -> None:
+def check_for_connector_deletion_task(tenant_id: str | None) -> None:
     r = get_redis_client()
 
     lock_beat = r.lock(
@@ -40,7 +40,7 @@ def check_for_connector_deletion_task() -> None:
             cc_pairs = get_connector_credential_pairs(db_session)
             for cc_pair in cc_pairs:
                 try_generate_document_cc_pair_cleanup_tasks(
-                    cc_pair, db_session, r, lock_beat
+                    cc_pair, db_session, r, lock_beat, tenant_id
                 )
     except SoftTimeLimitExceeded:
         task_logger.info(
@@ -58,6 +58,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
     db_session: Session,
     r: Redis,
     lock_beat: redis.lock.Lock,
+    tenant_id: str | None,
 ) -> int | None:
     """Returns an int if syncing is needed. The int represents the number of sync tasks generated.
     Note that syncing can still be required even if the number of sync tasks generated is zero.
@@ -90,7 +91,9 @@ def try_generate_document_cc_pair_cleanup_tasks(
     task_logger.info(
         f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
     )
-    tasks_generated = rcd.generate_tasks(celery_app, db_session, r, lock_beat)
+    tasks_generated = rcd.generate_tasks(
+        celery_app, db_session, r, lock_beat, tenant_id
+    )
     if tasks_generated is None:
         return None
 
@@ -24,17 +24,21 @@ from danswer.connectors.models import InputType
 from danswer.db.connector_credential_pair import get_connector_credential_pair
 from danswer.db.connector_credential_pair import get_connector_credential_pairs
 from danswer.db.document import get_documents_for_connector_credential_pair
-from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.engine import get_session_with_tenant
 from danswer.db.enums import ConnectorCredentialPairStatus
 from danswer.db.models import ConnectorCredentialPair
 from danswer.redis.redis_pool import get_redis_client
+from danswer.utils.logger import setup_logger
+
+
+logger = setup_logger()
 
 
 @shared_task(
     name="check_for_prune_task_2",
     soft_time_limit=JOB_TIMEOUT,
 )
-def check_for_prune_task_2() -> None:
+def check_for_prune_task_2(tenant_id: str | None) -> None:
     r = get_redis_client()
 
     lock_beat = r.lock(
@@ -47,11 +51,11 @@ def check_for_prune_task_2() -> None:
     if not lock_beat.acquire(blocking=False):
         return
 
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         cc_pairs = get_connector_credential_pairs(db_session)
         for cc_pair in cc_pairs:
             tasks_created = ccpair_pruning_generator_task_creation_helper(
-                cc_pair, db_session, r, lock_beat
+                cc_pair, db_session, tenant_id, r, lock_beat
             )
             if not tasks_created:
                 continue
@@ -71,6 +75,7 @@ def check_for_prune_task_2() -> None:
 def ccpair_pruning_generator_task_creation_helper(
     cc_pair: ConnectorCredentialPair,
     db_session: Session,
+    tenant_id: str | None,
     r: Redis,
     lock_beat: redis.lock.Lock,
 ) -> int | None:
@@ -101,13 +106,14 @@ def ccpair_pruning_generator_task_creation_helper(
     if datetime.now(timezone.utc) < next_prune:
         return None
 
-    return try_creating_prune_generator_task(cc_pair, db_session, r)
+    return try_creating_prune_generator_task(cc_pair, db_session, r, tenant_id)
 
 
 def try_creating_prune_generator_task(
     cc_pair: ConnectorCredentialPair,
     db_session: Session,
     r: Redis,
+    tenant_id: str | None,
 ) -> int | None:
     """Checks for any conditions that should block the pruning generator task from being
     created, then creates the task.
@@ -140,7 +146,9 @@ def try_creating_prune_generator_task(
     celery_app.send_task(
         "connector_pruning_generator_task",
         kwargs=dict(
-            connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
+            connector_id=cc_pair.connector_id,
+            credential_id=cc_pair.credential_id,
+            tenant_id=tenant_id,
         ),
         queue=DanswerCeleryQueues.CONNECTOR_PRUNING,
         task_id=custom_task_id,
@@ -153,14 +161,16 @@ def try_creating_prune_generator_task(
 
 
 @shared_task(name="connector_pruning_generator_task", soft_time_limit=JOB_TIMEOUT)
-def connector_pruning_generator_task(connector_id: int, credential_id: int) -> None:
+def connector_pruning_generator_task(
+    connector_id: int, credential_id: int, tenant_id: str | None
+) -> None:
     """connector pruning task. For a cc pair, this task pulls all document IDs from the source
     and compares those IDs to locally stored documents and deletes all locally stored IDs missing
     from the most recently pulled document ID list"""
 
     r = get_redis_client()
 
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         try:
             cc_pair = get_connector_credential_pair(
                 db_session=db_session,
@@ -218,7 +228,9 @@ def connector_pruning_generator_task(connector_id: int, credential_id: int) -> N
             task_logger.info(
                 f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}"
             )
-            tasks_generated = rcp.generate_tasks(celery_app, db_session, r, None)
+            tasks_generated = rcp.generate_tasks(
+                celery_app, db_session, r, None, tenant_id
+            )
             if tasks_generated is None:
                 return None
 
@@ -1,7 +1,6 @@
 from celery import shared_task
 from celery import Task
 from celery.exceptions import SoftTimeLimitExceeded
-from sqlalchemy.orm import Session
 
 from danswer.access.access import get_access_for_document
 from danswer.background.celery.celery_app import task_logger
@@ -11,7 +10,7 @@ from danswer.db.document import get_document
 from danswer.db.document import get_document_connector_count
 from danswer.db.document import mark_document_as_synced
 from danswer.db.document_set import fetch_document_sets_for_document
-from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.engine import get_session_with_tenant
 from danswer.document_index.document_index_utils import get_both_index_names
 from danswer.document_index.factory import get_default_document_index
 from danswer.document_index.interfaces import VespaDocumentFields
@@ -26,7 +25,11 @@ from danswer.server.documents.models import ConnectorCredentialPairIdentifier
     max_retries=3,
 )
 def document_by_cc_pair_cleanup_task(
-    self: Task, document_id: str, connector_id: int, credential_id: int
+    self: Task,
+    document_id: str,
+    connector_id: int,
+    credential_id: int,
+    tenant_id: str | None,
 ) -> bool:
     """A lightweight subtask used to clean up document to cc pair relationships.
     Created by connection deletion and connector pruning parent tasks."""
@@ -44,7 +47,7 @@ def document_by_cc_pair_cleanup_task(
     (6) delete all relevant entries from postgres
     """
     try:
-        with Session(get_sqlalchemy_engine()) as db_session:
+        with get_session_with_tenant(tenant_id) as db_session:
             action = "skip"
             chunks_affected = 0
 
@@ -38,6 +38,7 @@ from danswer.db.document_set import fetch_document_sets
 from danswer.db.document_set import fetch_document_sets_for_document
 from danswer.db.document_set import get_document_set_by_id
 from danswer.db.document_set import mark_document_set_as_synced
+from danswer.db.engine import get_session_with_tenant
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.index_attempt import delete_index_attempts
 from danswer.db.models import DocumentSet
@@ -61,7 +62,7 @@ from danswer.utils.variable_functionality import noop_fallback
     soft_time_limit=JOB_TIMEOUT,
     trail=False,
 )
-def check_for_vespa_sync_task() -> None:
+def check_for_vespa_sync_task(tenant_id: str | None) -> None:
     """Runs periodically to check if any document needs syncing.
     Generates sets of tasks for Celery if syncing is needed."""
 
@@ -77,8 +78,8 @@ def check_for_vespa_sync_task() -> None:
         if not lock_beat.acquire(blocking=False):
             return
 
-        with Session(get_sqlalchemy_engine()) as db_session:
-            try_generate_stale_document_sync_tasks(db_session, r, lock_beat)
+        with get_session_with_tenant(tenant_id) as db_session:
+            try_generate_stale_document_sync_tasks(db_session, r, lock_beat, tenant_id)
 
             # check if any document sets are not synced
             document_set_info = fetch_document_sets(
@@ -86,7 +87,7 @@ def check_for_vespa_sync_task() -> None:
             )
             for document_set, _ in document_set_info:
                 try_generate_document_set_sync_tasks(
-                    document_set, db_session, r, lock_beat
+                    document_set, db_session, r, lock_beat, tenant_id
                 )
 
             # check if any user groups are not synced
@@ -101,7 +102,7 @@ def check_for_vespa_sync_task() -> None:
                 )
                 for usergroup in user_groups:
                     try_generate_user_group_sync_tasks(
-                        usergroup, db_session, r, lock_beat
+                        usergroup, db_session, r, lock_beat, tenant_id
                     )
             except ModuleNotFoundError:
                 # Always exceptions on the MIT version, which is expected
@@ -120,7 +121,7 @@ def check_for_vespa_sync_task() -> None:
 
 
 def try_generate_stale_document_sync_tasks(
-    db_session: Session, r: Redis, lock_beat: redis.lock.Lock
+    db_session: Session, r: Redis, lock_beat: redis.lock.Lock, tenant_id: str | None
 ) -> int | None:
     # the fence is up, do nothing
     if r.exists(RedisConnectorCredentialPair.get_fence_key()):
@@ -145,7 +146,9 @@ def try_generate_stale_document_sync_tasks(
     cc_pairs = get_connector_credential_pairs(db_session)
     for cc_pair in cc_pairs:
         rc = RedisConnectorCredentialPair(cc_pair.id)
-        tasks_generated = rc.generate_tasks(celery_app, db_session, r, lock_beat)
+        tasks_generated = rc.generate_tasks(
+            celery_app, db_session, r, lock_beat, tenant_id
+        )
 
         if tasks_generated is None:
             continue
@@ -169,7 +172,11 @@ def try_generate_stale_document_sync_tasks(
 
 
 def try_generate_document_set_sync_tasks(
-    document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
+    document_set: DocumentSet,
+    db_session: Session,
+    r: Redis,
+    lock_beat: redis.lock.Lock,
+    tenant_id: str | None,
 ) -> int | None:
     lock_beat.reacquire()
 
@@ -193,7 +200,9 @@ def try_generate_document_set_sync_tasks(
     )
 
     # Add all documents that need to be updated into the queue
-    tasks_generated = rds.generate_tasks(celery_app, db_session, r, lock_beat)
+    tasks_generated = rds.generate_tasks(
+        celery_app, db_session, r, lock_beat, tenant_id
+    )
     if tasks_generated is None:
         return None
 
@@ -214,7 +223,11 @@ def try_generate_document_set_sync_tasks(
 
 
 def try_generate_user_group_sync_tasks(
-    usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
+    usergroup: UserGroup,
+    db_session: Session,
+    r: Redis,
+    lock_beat: redis.lock.Lock,
+    tenant_id: str | None,
 ) -> int | None:
     lock_beat.reacquire()
 
@@ -236,7 +249,9 @@ def try_generate_user_group_sync_tasks(
     task_logger.info(
         f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
     )
-    tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat)
+    tasks_generated = rug.generate_tasks(
+        celery_app, db_session, r, lock_beat, tenant_id
+    )
     if tasks_generated is None:
         return None
 
@@ -471,7 +486,7 @@ def monitor_ccpair_pruning_taskset(
 
 
 @shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
-def monitor_vespa_sync(self: Task) -> None:
+def monitor_vespa_sync(self: Task, tenant_id: str | None) -> None:
     """This is a celery beat task that monitors and finalizes metadata sync tasksets.
     It scans for fence values and then gets the counts of any associated tasksets.
     If the count is 0, that means all tasks finished and we should clean up.
@@ -516,7 +531,7 @@ def monitor_vespa_sync(self: Task) -> None:
     for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
         monitor_connector_deletion_taskset(key_bytes, r)
 
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         lock_beat.reacquire()
         for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
             monitor_document_set_taskset(key_bytes, r, db_session)
@@ -556,11 +571,13 @@ def monitor_vespa_sync(self: Task) -> None:
     time_limit=60,
     max_retries=3,
 )
-def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
+def vespa_metadata_sync_task(
+    self: Task, document_id: str, tenant_id: str | None
+) -> bool:
     task_logger.info(f"document_id={document_id}")
 
     try:
-        with Session(get_sqlalchemy_engine()) as db_session:
+        with get_session_with_tenant(tenant_id) as db_session:
             curr_ind_name, sec_ind_name = get_both_index_names(db_session)
             document_index = get_default_document_index(
                 primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
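
The Celery tasks above now take an explicit tenant_id, so whatever schedules them has to fan out per tenant. A minimal, hypothetical dispatcher sketch in Python; the Celery app/broker setup and the registered name of the sync check task are assumptions, only "monitor_vespa_sync" is confirmed by the decorator above:

from celery import Celery

celery_app = Celery("danswer")  # hypothetical app/broker config, not part of this commit

def dispatch_vespa_sync_per_tenant(tenant_ids: list[str | None]) -> None:
    # Send one kickoff per tenant so each worker run opens a session bound
    # to that tenant's schema via get_session_with_tenant.
    for tenant_id in tenant_ids:
        celery_app.send_task(
            "check_for_vespa_sync_task", kwargs={"tenant_id": tenant_id}  # task name assumed
        )
        celery_app.send_task("monitor_vespa_sync", kwargs={"tenant_id": tenant_id})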
@@ -4,6 +4,7 @@ from datetime import datetime
 from datetime import timedelta
 from datetime import timezone
 
+from sqlalchemy import text
 from sqlalchemy.orm import Session
 
 from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
@@ -17,7 +18,7 @@ from danswer.connectors.models import IndexAttemptMetadata
 from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
 from danswer.db.connector_credential_pair import get_last_successful_attempt_time
 from danswer.db.connector_credential_pair import update_connector_credential_pair
-from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.engine import get_session_with_tenant
 from danswer.db.enums import ConnectorCredentialPairStatus
 from danswer.db.index_attempt import get_index_attempt
 from danswer.db.index_attempt import mark_attempt_failed
@@ -46,6 +47,7 @@ def _get_connector_runner(
     attempt: IndexAttempt,
     start_time: datetime,
     end_time: datetime,
+    tenant_id: str | None,
 ) -> ConnectorRunner:
     """
     NOTE: `start_time` and `end_time` are only used for poll connectors
@@ -87,8 +89,7 @@ def _get_connector_runner(
 
 
 def _run_indexing(
-    db_session: Session,
-    index_attempt: IndexAttempt,
+    db_session: Session, index_attempt: IndexAttempt, tenant_id: str | None
 ) -> None:
     """
     1. Get documents which are either new or updated from specified application
@@ -129,6 +130,7 @@ def _run_indexing(
             or (search_settings.status == IndexModelStatus.FUTURE)
         ),
         db_session=db_session,
+        tenant_id=tenant_id,
     )
 
     db_cc_pair = index_attempt.connector_credential_pair
@@ -185,6 +187,7 @@ def _run_indexing(
                 attempt=index_attempt,
                 start_time=window_start,
                 end_time=window_end,
+                tenant_id=tenant_id,
             )
 
             all_connector_doc_ids: set[str] = set()
@@ -212,7 +215,9 @@ def _run_indexing(
                 db_session.refresh(index_attempt)
                 if index_attempt.status != IndexingStatus.IN_PROGRESS:
                     # Likely due to user manually disabling it or model swap
-                    raise RuntimeError("Index Attempt was canceled")
+                    raise RuntimeError(
+                        f"Index Attempt was canceled, status is {index_attempt.status}"
+                    )
 
                 batch_description = []
                 for doc in doc_batch:
@@ -373,12 +378,21 @@ def _run_indexing(
     )
 
 
-def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt:
+def _prepare_index_attempt(
+    db_session: Session, index_attempt_id: int, tenant_id: str | None
+) -> IndexAttempt:
     # make sure that the index attempt can't change in between checking the
     # status and marking it as in_progress. This setting will be discarded
     # after the next commit:
     # https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#setting-isolation-for-individual-transactions
     db_session.connection(execution_options={"isolation_level": "SERIALIZABLE"})  # type: ignore
+    if tenant_id is not None:
+        # Explicitly set the search path for the given tenant
+        db_session.execute(text(f'SET search_path TO "{tenant_id}"'))
+        # Verify the search path was set correctly
+        result = db_session.execute(text("SHOW search_path"))
+        current_search_path = result.scalar()
+        logger.info(f"Current search path set to: {current_search_path}")
 
     attempt = get_index_attempt(
         db_session=db_session,
@@ -401,12 +415,11 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt:
 
 
 def run_indexing_entrypoint(
-    index_attempt_id: int, connector_credential_pair_id: int, is_ee: bool = False
+    index_attempt_id: int,
+    tenant_id: str | None,
+    connector_credential_pair_id: int,
+    is_ee: bool = False,
 ) -> None:
-    """Entrypoint for indexing run when using dask distributed.
-    Wraps the actual logic in a `try` block so that we can catch any exceptions
-    and mark the attempt as failed."""
 
     try:
         if is_ee:
             global_version.set_ee()
@@ -416,26 +429,29 @@ def run_indexing_entrypoint(
         IndexAttemptSingleton.set_cc_and_index_id(
             index_attempt_id, connector_credential_pair_id
         )
-        with Session(get_sqlalchemy_engine()) as db_session:
-            # make sure that it is valid to run this indexing attempt + mark it
-            # as in progress
-            attempt = _prepare_index_attempt(db_session, index_attempt_id)
+        with get_session_with_tenant(tenant_id) as db_session:
+            attempt = _prepare_index_attempt(db_session, index_attempt_id, tenant_id)
 
             logger.info(
-                f"Indexing starting: "
-                f"connector='{attempt.connector_credential_pair.connector.name}' "
+                f"Indexing starting for tenant {tenant_id}: "
+                if tenant_id is not None
+                else ""
+                + f"connector='{attempt.connector_credential_pair.connector.name}' "
                 f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
                 f"credentials='{attempt.connector_credential_pair.connector_id}'"
             )
 
-            _run_indexing(db_session, attempt)
+            _run_indexing(db_session, attempt, tenant_id)
 
             logger.info(
-                f"Indexing finished: "
-                f"connector='{attempt.connector_credential_pair.connector.name}' "
+                f"Indexing finished for tenant {tenant_id}: "
+                if tenant_id is not None
+                else ""
+                + f"connector='{attempt.connector_credential_pair.connector.name}' "
                 f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
                 f"credentials='{attempt.connector_credential_pair.connector_id}'"
            )
     except Exception as e:
-        logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
+        logger.exception(
+            f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}"
+        )
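
For context, the schema switch that _prepare_index_attempt performs above is plain Postgres search_path manipulation on the current session. A standalone sketch of the same idea, assuming a placeholder DSN and tenant schema name:

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("postgresql+psycopg2://user:pass@localhost/danswer")  # placeholder DSN

with Session(engine) as db_session:
    tenant_id = "tenant_acme"  # hypothetical schema; validate it (e.g. is_valid_schema_name) first
    # Point the session at the tenant's schema, then verify the switch took effect.
    db_session.execute(text(f'SET search_path TO "{tenant_id}"'))
    current_search_path = db_session.execute(text("SHOW search_path")).scalar()
    print(current_search_path)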
@@ -6,6 +6,8 @@ import dask
 from dask.distributed import Client
 from dask.distributed import Future
 from distributed import LocalCluster
+from sqlalchemy import text
+from sqlalchemy.exc import ProgrammingError
 from sqlalchemy.orm import Session
 
 from danswer.background.indexing.dask_utils import ResourceLogger
@@ -15,14 +17,16 @@ from danswer.background.indexing.run_indexing import run_indexing_entrypoint
 from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
 from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
 from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
+from danswer.configs.app_configs import MULTI_TENANT
 from danswer.configs.app_configs import NUM_INDEXING_WORKERS
 from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
 from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
+from danswer.configs.constants import TENANT_ID_PREFIX
 from danswer.db.connector import fetch_connectors
 from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
 from danswer.db.engine import get_db_current_time
-from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.engine import get_session_with_tenant
 from danswer.db.engine import SqlEngine
 from danswer.db.index_attempt import create_index_attempt
 from danswer.db.index_attempt import get_index_attempt
@@ -153,13 +157,15 @@ def _mark_run_failed(
 """Main funcs"""
 
 
-def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
+def create_indexing_jobs(
+    existing_jobs: dict[int, Future | SimpleJob], tenant_id: str | None
+) -> None:
     """Creates new indexing jobs for each connector / credential pair which is:
     1. Enabled
     2. `refresh_frequency` time has passed since the last indexing run for this pair
     3. There is not already an ongoing indexing attempt for this pair
     """
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         ongoing: set[tuple[int | None, int]] = set()
         for attempt_id in existing_jobs:
             attempt = get_index_attempt(
@@ -214,11 +220,12 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
 
 def cleanup_indexing_jobs(
     existing_jobs: dict[int, Future | SimpleJob],
+    tenant_id: str | None,
     timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
 ) -> dict[int, Future | SimpleJob]:
     existing_jobs_copy = existing_jobs.copy()
     # clean up completed jobs
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         for attempt_id, job in existing_jobs.items():
             index_attempt = get_index_attempt(
                 db_session=db_session, index_attempt_id=attempt_id
@@ -256,11 +263,13 @@ def cleanup_indexing_jobs(
             )
 
     # clean up in-progress jobs that were never completed
+    try:
         connectors = fetch_connectors(db_session)
         for connector in connectors:
             in_progress_indexing_attempts = get_inprogress_index_attempts(
                 connector.id, db_session
             )
 
            for index_attempt in in_progress_indexing_attempts:
                if index_attempt.id in existing_jobs:
                    # If index attempt is canceled, stop the run
@@ -287,7 +296,8 @@ def cleanup_indexing_jobs(
                     index_attempt=index_attempt,
                     failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
                 )
+    except ProgrammingError:
+        logger.debug(f"No Connector Table exists for: {tenant_id}")
     return existing_jobs_copy
 
 
@@ -295,13 +305,15 @@ def kickoff_indexing_jobs(
     existing_jobs: dict[int, Future | SimpleJob],
     client: Client | SimpleJobClient,
     secondary_client: Client | SimpleJobClient,
+    tenant_id: str | None,
 ) -> dict[int, Future | SimpleJob]:
     existing_jobs_copy = existing_jobs.copy()
-    engine = get_sqlalchemy_engine()
+    current_session = get_session_with_tenant(tenant_id)
 
     # Don't include jobs waiting in the Dask queue that just haven't started running
     # Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
-    with Session(engine) as db_session:
+    with current_session as db_session:
         # get_not_started_index_attempts orders its returned results from oldest to newest
         # we must process attempts in a FIFO manner to prevent connector starvation
         new_indexing_attempts = [
@@ -332,7 +344,7 @@ def kickoff_indexing_jobs(
             logger.warning(
                 f"Skipping index attempt as Connector has been deleted: {attempt}"
             )
-            with Session(engine) as db_session:
+            with current_session as db_session:
                 mark_attempt_failed(
                     attempt, db_session, failure_reason="Connector is null"
                 )
@@ -341,7 +353,7 @@ def kickoff_indexing_jobs(
             logger.warning(
                 f"Skipping index attempt as Credential has been deleted: {attempt}"
             )
-            with Session(engine) as db_session:
+            with current_session as db_session:
                 mark_attempt_failed(
                     attempt, db_session, failure_reason="Credential is null"
                 )
@@ -352,6 +364,7 @@ def kickoff_indexing_jobs(
             run = client.submit(
                 run_indexing_entrypoint,
                 attempt.id,
+                tenant_id,
                 attempt.connector_credential_pair_id,
                 global_version.is_ee_version(),
                 pure=False,
@@ -363,6 +376,7 @@ def kickoff_indexing_jobs(
             run = secondary_client.submit(
                 run_indexing_entrypoint,
                 attempt.id,
+                tenant_id,
                 attempt.connector_credential_pair_id,
                 global_version.is_ee_version(),
                 pure=False,
@@ -398,42 +412,40 @@ def kickoff_indexing_jobs(
     return existing_jobs_copy
 
 
+def get_all_tenant_ids() -> list[str] | list[None]:
+    if not MULTI_TENANT:
+        return [None]
+    with get_session_with_tenant(tenant_id="public") as session:
+        result = session.execute(
+            text(
+                """
+                SELECT schema_name
+                FROM information_schema.schemata
+                WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
+            )
+        )
+        tenant_ids = [row[0] for row in result]
+
+    valid_tenants = [
+        tenant
+        for tenant in tenant_ids
+        if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
+    ]
+
+    return valid_tenants
+
+
 def update_loop(
     delay: int = 10,
     num_workers: int = NUM_INDEXING_WORKERS,
     num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
 ) -> None:
-    engine = get_sqlalchemy_engine()
-    with Session(engine) as db_session:
-        check_index_swap(db_session=db_session)
-        search_settings = get_current_search_settings(db_session)
-
-        # So that the first time users aren't surprised by really slow speed of first
-        # batch of documents indexed
-
-        if search_settings.provider_type is None:
-            logger.notice("Running a first inference to warm up embedding model")
-            embedding_model = EmbeddingModel.from_db_model(
-                search_settings=search_settings,
-                server_host=INDEXING_MODEL_SERVER_HOST,
-                server_port=MODEL_SERVER_PORT,
-            )
-
-            warm_up_bi_encoder(
-                embedding_model=embedding_model,
-            )
-            logger.notice("First inference complete.")
-
     client_primary: Client | SimpleJobClient
     client_secondary: Client | SimpleJobClient
     if DASK_JOB_CLIENT_ENABLED:
         cluster_primary = LocalCluster(
             n_workers=num_workers,
             threads_per_worker=1,
-            # there are warning about high memory usage + "Event loop unresponsive"
-            # which are not relevant to us since our workers are expected to use a
-            # lot of memory + involve CPU intensive tasks that will not relinquish
-            # the event loop
             silence_logs=logging.ERROR,
         )
         cluster_secondary = LocalCluster(
@@ -449,7 +461,7 @@ def update_loop(
         client_primary = SimpleJobClient(n_workers=num_workers)
         client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
 
-    existing_jobs: dict[int, Future | SimpleJob] = {}
+    existing_jobs: dict[str | None, dict[int, Future | SimpleJob]] = {}
 
     logger.notice("Startup complete. Waiting for indexing jobs...")
     while True:
@@ -458,24 +470,58 @@ def update_loop(
         logger.debug(f"Running update, current UTC time: {start_time_utc}")
 
         if existing_jobs:
-            # TODO: make this debug level once the "no jobs are being scheduled" issue is resolved
             logger.debug(
                 "Found existing indexing jobs: "
-                f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
+                f"{[(tenant_id, list(jobs.keys())) for tenant_id, jobs in existing_jobs.items()]}"
             )
 
         try:
-            with Session(get_sqlalchemy_engine()) as db_session:
-                check_index_swap(db_session)
-            existing_jobs = cleanup_indexing_jobs(existing_jobs=existing_jobs)
-            create_indexing_jobs(existing_jobs=existing_jobs)
-            existing_jobs = kickoff_indexing_jobs(
-                existing_jobs=existing_jobs,
-                client=client_primary,
-                secondary_client=client_secondary,
-            )
+            tenants = get_all_tenant_ids()
+
+            for tenant_id in tenants:
+                try:
+                    logger.debug(
+                        f"Processing {'index attempts' if tenant_id is None else f'tenant {tenant_id}'}"
+                    )
+                    with get_session_with_tenant(tenant_id) as db_session:
+                        check_index_swap(db_session=db_session)
+                        if not MULTI_TENANT:
+                            search_settings = get_current_search_settings(db_session)
+                            if search_settings.provider_type is None:
+                                logger.notice(
+                                    "Running a first inference to warm up embedding model"
+                                )
+                                embedding_model = EmbeddingModel.from_db_model(
+                                    search_settings=search_settings,
+                                    server_host=INDEXING_MODEL_SERVER_HOST,
+                                    server_port=MODEL_SERVER_PORT,
+                                )
+                                warm_up_bi_encoder(embedding_model=embedding_model)
+                                logger.notice("First inference complete.")
+
+                    tenant_jobs = existing_jobs.get(tenant_id, {})
+
+                    tenant_jobs = cleanup_indexing_jobs(
+                        existing_jobs=tenant_jobs, tenant_id=tenant_id
+                    )
+                    create_indexing_jobs(existing_jobs=tenant_jobs, tenant_id=tenant_id)
+                    tenant_jobs = kickoff_indexing_jobs(
+                        existing_jobs=tenant_jobs,
+                        client=client_primary,
+                        secondary_client=client_secondary,
+                        tenant_id=tenant_id,
+                    )
+
+                    existing_jobs[tenant_id] = tenant_jobs
+
+                except Exception as e:
+                    logger.exception(
+                        f"Failed to process tenant {tenant_id or 'default'}: {e}"
+                    )
+
         except Exception as e:
             logger.exception(f"Failed to run update due to {e}")
 
         sleep_time = delay - (time.time() - start)
         if sleep_time > 0:
             time.sleep(sleep_time)
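
The tenant discovery in get_all_tenant_ids above is a query over information_schema filtered by the tenant prefix. A condensed, standalone sketch of that lookup; the connection string is a placeholder:

from sqlalchemy import create_engine, text

TENANT_ID_PREFIX = "tenant_"  # mirrors the constant added in this commit

engine = create_engine("postgresql+psycopg2://user:pass@localhost/danswer")  # placeholder DSN

with engine.connect() as conn:
    rows = conn.execute(
        text(
            "SELECT schema_name FROM information_schema.schemata "
            "WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"
        )
    )
    # Only schemas created by tenant provisioning count as tenants.
    tenant_ids = [row[0] for row in rows if row[0].startswith(TENANT_ID_PREFIX)]
    print(tenant_ids)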
@@ -429,3 +429,5 @@ SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "")
 
 DATA_PLANE_SECRET = os.environ.get("DATA_PLANE_SECRET", "")
 EXPECTED_API_KEY = os.environ.get("EXPECTED_API_KEY", "")
+
+ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"
@@ -31,6 +31,9 @@ DISABLED_GEN_AI_MSG = (
     "You can still use Danswer as a search engine."
 )
 
+# Prefix used for all tenant ids
+TENANT_ID_PREFIX = "tenant_"
+
 # Postgres connection constants for application_name
 POSTGRES_WEB_APP_NAME = "web"
 POSTGRES_INDEXER_APP_NAME = "indexer"
@@ -13,7 +13,7 @@ from sqlalchemy.future import select
 
 from danswer.auth.schemas import UserRole
 from danswer.db.engine import get_async_session
-from danswer.db.engine import get_sqlalchemy_async_engine
+from danswer.db.engine import get_async_session_with_tenant
 from danswer.db.models import AccessToken
 from danswer.db.models import OAuthAccount
 from danswer.db.models import User
@@ -34,7 +34,7 @@ def get_default_admin_user_emails() -> list[str]:
 
 
 async def get_user_count() -> int:
-    async with AsyncSession(get_sqlalchemy_async_engine()) as asession:
+    async with get_async_session_with_tenant() as asession:
         stmt = select(func.count(User.id))
         result = await asession.execute(stmt)
         user_count = result.scalar()
@@ -390,6 +390,7 @@ def add_credential_to_connector(
     )
     db_session.add(association)
     db_session.flush()  # make sure the association has an id
+    db_session.refresh(association)
 
     if groups and access_type != AccessType.SYNC:
         _relate_groups_to_cc_pair__no_commit(
@@ -1,16 +1,16 @@
 import contextlib
-import contextvars
 import re
 import threading
 import time
 from collections.abc import AsyncGenerator
 from collections.abc import Generator
+from contextlib import asynccontextmanager
+from contextlib import contextmanager
 from datetime import datetime
 from typing import Any
 from typing import ContextManager
 
 import jwt
-from fastapi import Depends
 from fastapi import HTTPException
 from fastapi import Request
 from sqlalchemy import event
@@ -39,7 +39,7 @@ from danswer.configs.app_configs import SECRET_JWT_KEY
 from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
 from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
 from danswer.utils.logger import setup_logger
+from shared_configs.configs import current_tenant_id
 
 logger = setup_logger()
 
@@ -230,18 +230,8 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
     return _ASYNC_ENGINE
 
 
-# Context variable to store the current tenant ID
-# This allows us to maintain tenant-specific context throughout the request lifecycle
-# The default value is set to POSTGRES_DEFAULT_SCHEMA for non-multi-tenant setups
-# This context variable works in both synchronous and asynchronous contexts
-# In async code, it's automatically carried across coroutines
-# In sync code, it's managed per thread
-current_tenant_id = contextvars.ContextVar(
-    "current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
-)
-
-
-# Dependency to get the current tenant ID and set the context variable
+# Dependency to get the current tenant ID
+# If no token is present, uses the default schema for this use case
 def get_current_tenant_id(request: Request) -> str:
     """Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
     if not MULTI_TENANT:
@@ -251,32 +241,31 @@ def get_current_tenant_id(request: Request) -> str:
 
     token = request.cookies.get("tenant_details")
     if not token:
+        current_value = current_tenant_id.get()
         # If no token is present, use the default schema or handle accordingly
-        tenant_id = POSTGRES_DEFAULT_SCHEMA
-        current_tenant_id.set(tenant_id)
-        return tenant_id
+        return current_value
 
     try:
         payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
         tenant_id = payload.get("tenant_id")
         if not tenant_id:
-            raise HTTPException(
-                status_code=400, detail="Invalid token: tenant_id missing"
-            )
+            return current_tenant_id.get()
         if not is_valid_schema_name(tenant_id):
-            raise ValueError("Invalid tenant ID format")
+            raise HTTPException(status_code=400, detail="Invalid tenant ID format")
         current_tenant_id.set(tenant_id)
 
         return tenant_id
     except jwt.InvalidTokenError:
-        raise HTTPException(status_code=401, detail="Invalid token format")
-    except ValueError as e:
-        # Let the 400 error bubble up
-        raise HTTPException(status_code=400, detail=str(e))
-    except Exception:
+        return current_tenant_id.get()
+    except Exception as e:
+        logger.error(f"Unexpected error in get_current_tenant_id: {str(e)}")
         raise HTTPException(status_code=500, detail="Internal server error")
 
 
-def get_session_with_tenant(tenant_id: str | None = None) -> Session:
+@asynccontextmanager
+async def get_async_session_with_tenant(
+    tenant_id: str | None = None,
+) -> AsyncGenerator[AsyncSession, None]:
     if tenant_id is None:
         tenant_id = current_tenant_id.get()
 
@@ -284,20 +273,78 @@ def get_session_with_tenant(tenant_id: str | None = None) -> Session:
         logger.error(f"Invalid tenant ID: {tenant_id}")
         raise Exception("Invalid tenant ID")
 
-    engine = SqlEngine.get_engine()
-    session = Session(engine, expire_on_commit=False)
+    engine = get_sqlalchemy_async_engine()
+    async_session_factory = sessionmaker(
+        bind=engine, expire_on_commit=False, class_=AsyncSession
+    )  # type: ignore
 
-    @event.listens_for(session, "after_begin")
-    def set_search_path(session: Session, transaction: Any, connection: Any) -> None:
-        connection.execute(text("SET search_path TO :schema"), {"schema": tenant_id})
-
-    return session
+    async with async_session_factory() as session:
+        try:
+            # Set the search_path to the tenant's schema
+            await session.execute(text(f'SET search_path = "{tenant_id}"'))
+        except Exception as e:
+            logger.error(f"Error setting search_path: {str(e)}")
+            # You can choose to re-raise the exception or handle it
+            # Here, we'll re-raise to prevent proceeding with an incorrect session
+            raise
+        else:
+            yield session
 
 
-def get_session(
-    tenant_id: str = Depends(get_current_tenant_id),
+@contextmanager
+def get_session_with_tenant(
+    tenant_id: str | None = None,
 ) -> Generator[Session, None, None]:
     """Generate a database session with the appropriate tenant schema set."""
+    engine = get_sqlalchemy_engine()
+    if tenant_id is None:
+        tenant_id = current_tenant_id.get()
+
+    if not is_valid_schema_name(tenant_id):
+        raise HTTPException(status_code=400, detail="Invalid tenant ID")
+
+    # Establish a raw connection without starting a transaction
+    with engine.connect() as connection:
+        # Access the raw DBAPI connection
+        dbapi_connection = connection.connection
+
+        # Execute SET search_path outside of any transaction
+        cursor = dbapi_connection.cursor()
+        try:
+            cursor.execute(f'SET search_path TO "{tenant_id}"')
+            # Optionally verify the search_path was set correctly
+            cursor.execute("SHOW search_path")
+            cursor.fetchone()
+        finally:
+            cursor.close()
+
+        # Proceed to create a session using the connection
+        with Session(bind=connection, expire_on_commit=False) as session:
+            try:
+                yield session
+            finally:
+                # Reset search_path to default after the session is used
+                if MULTI_TENANT:
+                    cursor = dbapi_connection.cursor()
+                    try:
+                        cursor.execute('SET search_path TO "$user", public')
+                    finally:
+                        cursor.close()
+
+
+def get_session_generator_with_tenant(
+    tenant_id: str | None = None,
+) -> Generator[Session, None, None]:
+    with get_session_with_tenant(tenant_id) as session:
+        yield session
+
+
+def get_session() -> Generator[Session, None, None]:
+    """Generate a database session with the appropriate tenant schema set."""
+    tenant_id = current_tenant_id.get()
+    if tenant_id == "public" and MULTI_TENANT:
+        raise HTTPException(status_code=401, detail="User must authenticate")
+
     engine = get_sqlalchemy_engine()
     with Session(engine, expire_on_commit=False) as session:
         if MULTI_TENANT:
@@ -308,10 +355,9 @@ def get_session(
         yield session
 
 
-async def get_async_session(
-    tenant_id: str = Depends(get_current_tenant_id),
-) -> AsyncGenerator[AsyncSession, None]:
+async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
     """Generate an async database session with the appropriate tenant schema set."""
+    tenant_id = current_tenant_id.get()
     engine = get_sqlalchemy_async_engine()
     async with AsyncSession(engine, expire_on_commit=False) as async_session:
         if MULTI_TENANT:
@@ -324,7 +370,7 @@ async def get_async_session(
 
 def get_session_context_manager() -> ContextManager[Session]:
     """Context manager for database sessions."""
-    return contextlib.contextmanager(get_session)()
+    return contextlib.contextmanager(get_session_generator_with_tenant)()
 
 
 def get_session_factory() -> sessionmaker[Session]:
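
Callers of the helpers above never touch the engine directly; they ask for a session already scoped to a tenant, falling back to the current_tenant_id context variable when no explicit id is passed. A small usage sketch (the function and query are hypothetical; the helper import matches the diff):

from sqlalchemy import text

from danswer.db.engine import get_session_with_tenant

def count_documents_for_tenant(tenant_id: str | None) -> int:
    # With MULTI_TENANT enabled, the session's search_path is switched to the
    # tenant's schema on entry and reset to the default when the block exits.
    with get_session_with_tenant(tenant_id) as db_session:
        return db_session.execute(text("SELECT count(*) FROM document")).scalar_one()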
@@ -1763,3 +1763,23 @@ class UsageReport(Base):
 
     requestor = relationship("User")
     file = relationship("PGFileStore")
+
+
+"""
+Multi-tenancy related tables
+"""
+
+
+class PublicBase(DeclarativeBase):
+    __abstract__ = True
+
+
+class UserTenantMapping(Base):
+    __tablename__ = "user_tenant_mapping"
+    __table_args__ = (
+        UniqueConstraint("email", "tenant_id", name="uq_user_tenant"),
+        {"schema": "public"},
+    )
+
+    email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
+    tenant_id: Mapped[str] = mapped_column(String, nullable=False)
@@ -137,6 +137,7 @@ def index_doc_batch_with_handler(
     attempt_id: int | None,
     db_session: Session,
     ignore_time_skip: bool = False,
+    tenant_id: str | None = None,
 ) -> tuple[int, int]:
     r = (0, 0)
     try:
@@ -148,6 +149,7 @@ def index_doc_batch_with_handler(
             index_attempt_metadata=index_attempt_metadata,
             db_session=db_session,
             ignore_time_skip=ignore_time_skip,
+            tenant_id=tenant_id,
         )
     except Exception as e:
         if INDEXING_EXCEPTION_LIMIT == 0:
@@ -261,6 +263,7 @@ def index_doc_batch(
     index_attempt_metadata: IndexAttemptMetadata,
     db_session: Session,
     ignore_time_skip: bool = False,
+    tenant_id: str | None = None,
 ) -> tuple[int, int]:
     """Takes different pieces of the indexing pipeline and applies it to a batch of documents
     Note that the documents should already be batched at this point so that it does not inflate the
@@ -324,6 +327,7 @@ def index_doc_batch(
                 if chunk.source_document.id in ctx.id_to_db_doc_map
                 else DEFAULT_BOOST
             ),
+            tenant_id=tenant_id,
         )
         for chunk in chunks_with_embeddings
     ]
@@ -373,6 +377,7 @@ def build_indexing_pipeline(
     chunker: Chunker | None = None,
     ignore_time_skip: bool = False,
     attempt_id: int | None = None,
+    tenant_id: str | None = None,
 ) -> IndexingPipelineProtocol:
     """Builds a pipeline which takes in a list (batch) of docs and indexes them."""
     search_settings = get_current_search_settings(db_session)
@@ -416,4 +421,5 @@ def build_indexing_pipeline(
         ignore_time_skip=ignore_time_skip,
         attempt_id=attempt_id,
         db_session=db_session,
+        tenant_id=tenant_id,
     )
@@ -75,6 +75,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
     negative -> ranked lower.
     """
 
+    tenant_id: str | None = None
     access: "DocumentAccess"
     document_sets: set[str]
     boost: int
@@ -86,6 +87,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
         access: "DocumentAccess",
         document_sets: set[str],
         boost: int,
+        tenant_id: str | None,
     ) -> "DocMetadataAwareIndexChunk":
         index_chunk_data = index_chunk.model_dump()
         return cls(
@@ -93,6 +95,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
             access=access,
             document_sets=document_sets,
             boost=boost,
+            tenant_id=tenant_id,
         )
 
 
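
With the field added above, every chunk written to the document index carries its tenant. A sketch of the corresponding call site; the surrounding values are placeholders and the import path is an assumption:

# from danswer.indexing.models import DocMetadataAwareIndexChunk  # import path assumed

chunk_with_metadata = DocMetadataAwareIndexChunk.from_index_chunk(
    index_chunk=chunk,              # an embedded IndexChunk produced earlier in the pipeline
    access=document_access,         # a DocumentAccess computed for the source document
    document_sets={"engineering"},  # illustrative document-set membership
    boost=0,                        # neutral boost, matching DEFAULT_BOOST semantics
    tenant_id=tenant_id,            # None in single-tenant deployments
)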
@@ -3,15 +3,21 @@ from collections.abc import Iterator
 from contextlib import contextmanager
 from typing import cast
 
+from fastapi import HTTPException
+from sqlalchemy import text
 from sqlalchemy.orm import Session
 
+from danswer.configs.app_configs import MULTI_TENANT
 from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.engine import is_valid_schema_name
 from danswer.db.models import KVStore
 from danswer.key_value_store.interface import JSON_ro
 from danswer.key_value_store.interface import KeyValueStore
 from danswer.key_value_store.interface import KvKeyNotFoundError
 from danswer.redis.redis_pool import get_redis_client
 from danswer.utils.logger import setup_logger
+from shared_configs.configs import current_tenant_id
 
 
 logger = setup_logger()
 
@@ -28,6 +34,16 @@ class PgRedisKVStore(KeyValueStore):
     def get_session(self) -> Iterator[Session]:
         engine = get_sqlalchemy_engine()
         with Session(engine, expire_on_commit=False) as session:
+            if MULTI_TENANT:
+                tenant_id = current_tenant_id.get()
+                if tenant_id == "public":
+                    raise HTTPException(
+                        status_code=401, detail="User must authenticate"
+                    )
+                if not is_valid_schema_name(tenant_id):
+                    raise HTTPException(status_code=400, detail="Invalid tenant ID")
+                # Set the search_path to the tenant's schema
+                session.execute(text(f'SET search_path = "{tenant_id}"'))
             yield session
 
     def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
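Note: the sketch below is illustrative only and is not part of this commit. It restates the session-scoping pattern the change above relies on; the imports come from the modules touched in this diff, while the helper name open_tenant_scoped_session is hypothetical.

# Illustrative sketch, not part of the commit: mirrors what
# PgRedisKVStore.get_session() now does -- read the tenant from the
# contextvar and pin the session's search_path to that schema.
from sqlalchemy import text
from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine, is_valid_schema_name
from shared_configs.configs import current_tenant_id


def open_tenant_scoped_session() -> Session:  # hypothetical helper
    tenant_id = current_tenant_id.get()  # "public" unless middleware set a tenant
    if not is_valid_schema_name(tenant_id):
        # the schema name is interpolated into SQL below, so validate it first
        raise ValueError(f"invalid tenant schema: {tenant_id!r}")
    session = Session(get_sqlalchemy_engine(), expire_on_commit=False)
    session.execute(text(f'SET search_path = "{tenant_id}"'))
    return session
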
@@ -29,6 +29,7 @@ from danswer.configs.app_configs import APP_PORT
 from danswer.configs.app_configs import AUTH_TYPE
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.app_configs import LOG_ENDPOINT_LATENCY
+from danswer.configs.app_configs import MULTI_TENANT
 from danswer.configs.app_configs import OAUTH_CLIENT_ID
 from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
 from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
@@ -157,6 +158,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
     verify_auth = fetch_versioned_implementation(
         "danswer.auth.users", "verify_auth_setting"
     )
 
     # Will throw exception if an issue is found
     verify_auth()
 
@@ -169,9 +171,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
     # fill up Postgres connection pools
     await warm_up_connections()
 
+    if not MULTI_TENANT:
         # We cache this at the beginning so there is no delay in the first telemetry
         get_or_generate_uuid()
 
+    # If we are multi-tenant, we need to only set up initial public tables
     with Session(engine) as db_session:
         setup_danswer(db_session)
@@ -22,6 +22,7 @@ from danswer.db.connector_credential_pair import (
     update_connector_credential_pair_from_id,
 )
 from danswer.db.document import get_document_counts_for_cc_pairs
+from danswer.db.engine import current_tenant_id
 from danswer.db.engine import get_session
 from danswer.db.enums import AccessType
 from danswer.db.enums import ConnectorCredentialPairStatus
@@ -257,7 +258,9 @@ def prune_cc_pair(
         f"credential_id={cc_pair.credential_id} "
         f"{cc_pair.connector.name} connector."
     )
-    tasks_created = try_creating_prune_generator_task(cc_pair, db_session, r)
+    tasks_created = try_creating_prune_generator_task(
+        cc_pair, db_session, r, current_tenant_id.get()
+    )
     if not tasks_created:
         raise HTTPException(
             status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
@@ -342,7 +345,7 @@ def sync_cc_pair(
 
     logger.info(f"Syncing the {cc_pair.connector.name} connector.")
     sync_external_doc_permissions_task.apply_async(
-        kwargs=dict(cc_pair_id=cc_pair_id),
+        kwargs=dict(cc_pair_id=cc_pair_id, tenant_id=current_tenant_id.get()),
     )
 
     return StatusResponse(
@@ -20,6 +20,7 @@ from danswer.db.connector_credential_pair import (
     update_connector_credential_pair_from_id,
 )
 from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
+from danswer.db.engine import get_current_tenant_id
 from danswer.db.engine import get_session
 from danswer.db.enums import ConnectorCredentialPairStatus
 from danswer.db.feedback import fetch_docs_ranked_by_boost
@@ -146,6 +147,7 @@ def create_deletion_attempt_for_connector_id(
     connector_credential_pair_identifier: ConnectorCredentialPairIdentifier,
     user: User = Depends(current_curator_or_admin_user),
     db_session: Session = Depends(get_session),
+    tenant_id: str = Depends(get_current_tenant_id),
 ) -> None:
     connector_id = connector_credential_pair_identifier.connector_id
     credential_id = connector_credential_pair_identifier.credential_id
@@ -196,6 +198,7 @@ def create_deletion_attempt_for_connector_id(
     celery_app.send_task(
         "check_for_connector_deletion_task",
         priority=DanswerCeleryPriority.HIGH,
+        kwargs={"tenant_id": tenant_id},
     )
 
     if cc_pair.connector.source == DocumentSource.FILE:
@@ -2,17 +2,21 @@ import re
 from datetime import datetime
 from datetime import timezone
 
+import jwt
 from email_validator import validate_email
 from fastapi import APIRouter
 from fastapi import Body
 from fastapi import Depends
 from fastapi import HTTPException
+from fastapi import Request
 from fastapi import status
+from psycopg2.errors import UniqueViolation
 from pydantic import BaseModel
 from sqlalchemy import Column
 from sqlalchemy import desc
 from sqlalchemy import select
 from sqlalchemy import update
+from sqlalchemy.exc import IntegrityError
 from sqlalchemy.orm import Session
 
 from danswer.auth.invited_users import get_invited_users
@@ -26,9 +30,12 @@ from danswer.auth.users import current_curator_or_admin_user
 from danswer.auth.users import current_user
 from danswer.auth.users import optional_user
 from danswer.configs.app_configs import AUTH_TYPE
+from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
+from danswer.configs.app_configs import MULTI_TENANT
 from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
 from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
 from danswer.configs.constants import AuthType
+from danswer.db.engine import current_tenant_id
 from danswer.db.engine import get_session
 from danswer.db.models import AccessToken
 from danswer.db.models import DocumentSet__User
@@ -48,10 +55,13 @@ from danswer.server.manage.models import UserRoleUpdateRequest
 from danswer.server.models import FullUserSnapshot
 from danswer.server.models import InvitedUserSnapshot
 from danswer.server.models import MinimalUserSnapshot
+from danswer.server.utils import send_user_email_invite
 from danswer.utils.logger import setup_logger
 from ee.danswer.db.api_key import is_api_key_email_address
 from ee.danswer.db.external_perm import delete_user__ext_group_for_user__no_commit
 from ee.danswer.db.user_group import remove_curator_status__no_commit
+from ee.danswer.server.tenants.provisioning import add_users_to_tenant
+from ee.danswer.server.tenants.provisioning import remove_users_from_tenant
 
 logger = setup_logger()
 
@@ -171,12 +181,33 @@ def bulk_invite_users(
         raise HTTPException(
             status_code=400, detail="Auth is disabled, cannot invite users"
         )
+    tenant_id = current_tenant_id.get()
 
     normalized_emails = []
     for email in emails:
         email_info = validate_email(email)  # can raise EmailNotValidError
         normalized_emails.append(email_info.normalized)  # type: ignore
+
+    if MULTI_TENANT:
+        try:
+            add_users_to_tenant(normalized_emails, tenant_id)
+        except IntegrityError as e:
+            if isinstance(e.orig, UniqueViolation):
+                raise HTTPException(
+                    status_code=400,
+                    detail="User has already been invited to a Danswer organization",
+                )
+            raise
+
     all_emails = list(set(normalized_emails) | set(get_invited_users()))
+
+    if MULTI_TENANT and ENABLE_EMAIL_INVITES:
+        try:
+            for email in all_emails:
+                send_user_email_invite(email, current_user)
+        except Exception as e:
+            logger.error(f"Error sending email invite to invited users: {e}")
+
     return write_invited_users(all_emails)
 
 
@@ -187,6 +218,10 @@ def remove_invited_user(
 ) -> int:
     user_emails = get_invited_users()
     remaining_users = [user for user in user_emails if user != user_email.user_email]
+
+    tenant_id = current_tenant_id.get()
+    remove_users_from_tenant([user_email.user_email], tenant_id)
+
     return write_invited_users(remaining_users)
 
 
@@ -330,6 +365,35 @@ async def get_user_role(user: User = Depends(current_user)) -> UserRoleResponse:
     return UserRoleResponse(role=user.role)
 
 
+def get_current_token_expiration_jwt(
+    user: User | None, request: Request
+) -> datetime | None:
+    if user is None:
+        return None
+
+    try:
+        # Get the JWT from the cookie
+        jwt_token = request.cookies.get("fastapiusersauth")
+        if not jwt_token:
+            logger.error("No JWT token found in cookies")
+            return None
+
+        # Decode the JWT
+        decoded_token = jwt.decode(jwt_token, options={"verify_signature": False})
+
+        # Get the 'exp' (expiration) claim from the token
+        exp = decoded_token.get("exp")
+        if exp:
+            return datetime.fromtimestamp(exp)
+        else:
+            logger.error("No 'exp' claim found in JWT")
+            return None
+
+    except Exception as e:
+        logger.error(f"Error decoding JWT: {e}")
+        return None
+
+
 def get_current_token_creation(
     user: User | None, db_session: Session
 ) -> datetime | None:
@@ -357,6 +421,7 @@ def get_current_token_creation(
 
 @router.get("/me")
 def verify_user_logged_in(
+    request: Request,
     user: User | None = Depends(optional_user),
     db_session: Session = Depends(get_session),
 ) -> UserInfo:
@@ -380,7 +445,9 @@ def verify_user_logged_in(
             detail="Access denied. User's OIDC token has expired.",
         )
 
-    token_created_at = get_current_token_creation(user, db_session)
+    token_created_at = (
+        None if MULTI_TENANT else get_current_token_creation(user, db_session)
+    )
     user_info = UserInfo.from_model(
         user,
         current_token_created_at=token_created_at,
@@ -73,6 +73,7 @@ from danswer.server.query_and_chat.models import UpdateChatSessionThreadRequest
 from danswer.server.query_and_chat.token_limit import check_token_rate_limits
 from danswer.utils.logger import setup_logger
 
 logger = setup_logger()
 
 router = APIRouter(prefix="/chat")
@@ -1,7 +1,17 @@
 import json
+import smtplib
 from datetime import datetime
+from email.mime.multipart import MIMEMultipart
+from email.mime.text import MIMEText
 from typing import Any
 
+from danswer.configs.app_configs import SMTP_PASS
+from danswer.configs.app_configs import SMTP_PORT
+from danswer.configs.app_configs import SMTP_SERVER
+from danswer.configs.app_configs import SMTP_USER
+from danswer.configs.app_configs import WEB_DOMAIN
+from danswer.db.models import User
 
 
 class DateTimeEncoder(json.JSONEncoder):
     """Custom JSON encoder that converts datetime objects to ISO format strings."""
@@ -43,3 +53,28 @@ def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
 
         masked_creds[key] = mask_string(val)
     return masked_creds
+
+
+def send_user_email_invite(user_email: str, current_user: User) -> None:
+    msg = MIMEMultipart()
+    msg["Subject"] = "Invitation to Join Danswer Workspace"
+    msg["To"] = user_email
+    msg["From"] = current_user.email
+
+    email_body = f"""
+Hello,
+
+You have been invited to join a workspace on Danswer.
+
+To join the workspace, please do so at the following link:
+{WEB_DOMAIN}/auth/login
+
+Best regards,
+The Danswer Team"""
+
+    msg.attach(MIMEText(email_body, "plain"))
+
+    with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as smtp_server:
+        smtp_server.starttls()
+        smtp_server.login(SMTP_USER, SMTP_PASS)
+        smtp_server.send_message(msg)
@@ -4,6 +4,7 @@ from sqlalchemy.orm import Session
 
 from danswer.chat.load_yamls import load_chat_yamls
 from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
+from danswer.configs.app_configs import MULTI_TENANT
 from danswer.configs.constants import KV_REINDEX_KEY
 from danswer.configs.constants import KV_SEARCH_SETTINGS
 from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
@@ -98,6 +99,7 @@ def setup_danswer(db_session: Session) -> None:
 
     # Does the user need to trigger a reindexing to bring the document index
     # into a good state, marked in the kv store
+    if not MULTI_TENANT:
         mark_reindex_flag(db_session)
 
     # ensure Vespa is setup correctly
@@ -1,12 +1,12 @@
 from datetime import timedelta
 
-from sqlalchemy.orm import Session
-
 from danswer.background.celery.celery_app import celery_app
 from danswer.background.task_utils import build_celery_task_wrapper
+from danswer.background.update import get_all_tenant_ids
 from danswer.configs.app_configs import JOB_TIMEOUT
+from danswer.configs.app_configs import MULTI_TENANT
 from danswer.db.chat import delete_chat_sessions_older_than
-from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.engine import get_session_with_tenant
 from danswer.server.settings.store import load_settings
 from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import global_version
@@ -32,6 +32,7 @@ from ee.danswer.external_permissions.permission_sync import (
     run_external_group_permission_sync,
 )
 from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
+from shared_configs.configs import current_tenant_id
 
 logger = setup_logger()
 
@@ -41,22 +42,26 @@ global_version.set_ee()
 
 @build_celery_task_wrapper(name_sync_external_doc_permissions_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
-def sync_external_doc_permissions_task(cc_pair_id: int) -> None:
-    with Session(get_sqlalchemy_engine()) as db_session:
+def sync_external_doc_permissions_task(cc_pair_id: int, tenant_id: str | None) -> None:
+    with get_session_with_tenant(tenant_id) as db_session:
         run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
 
 
 @build_celery_task_wrapper(name_sync_external_group_permissions_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
-def sync_external_group_permissions_task(cc_pair_id: int) -> None:
-    with Session(get_sqlalchemy_engine()) as db_session:
+def sync_external_group_permissions_task(
+    cc_pair_id: int, tenant_id: str | None
+) -> None:
+    with get_session_with_tenant(tenant_id) as db_session:
         run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
 
 
 @build_celery_task_wrapper(name_chat_ttl_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
-def perform_ttl_management_task(retention_limit_days: int) -> None:
-    with Session(get_sqlalchemy_engine()) as db_session:
+def perform_ttl_management_task(
+    retention_limit_days: int, tenant_id: str | None
+) -> None:
+    with get_session_with_tenant(tenant_id) as db_session:
         delete_chat_sessions_older_than(retention_limit_days, db_session)
 
 
@@ -67,16 +72,16 @@ def perform_ttl_management_task(retention_limit_days: int) -> None:
     name="check_sync_external_doc_permissions_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def check_sync_external_doc_permissions_task() -> None:
+def check_sync_external_doc_permissions_task(tenant_id: str | None) -> None:
     """Runs periodically to sync external permissions"""
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         cc_pairs = get_all_auto_sync_cc_pairs(db_session)
         for cc_pair in cc_pairs:
             if should_perform_external_doc_permissions_check(
                 cc_pair=cc_pair, db_session=db_session
             ):
                 sync_external_doc_permissions_task.apply_async(
-                    kwargs=dict(cc_pair_id=cc_pair.id),
+                    kwargs=dict(cc_pair_id=cc_pair.id, tenant_id=tenant_id),
                 )
 
 
@@ -84,16 +89,16 @@ def check_sync_external_doc_permissions_task() -> None:
     name="check_sync_external_group_permissions_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def check_sync_external_group_permissions_task() -> None:
+def check_sync_external_group_permissions_task(tenant_id: str | None) -> None:
     """Runs periodically to sync external group permissions"""
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         cc_pairs = get_all_auto_sync_cc_pairs(db_session)
         for cc_pair in cc_pairs:
             if should_perform_external_group_permissions_check(
                 cc_pair=cc_pair, db_session=db_session
             ):
                 sync_external_group_permissions_task.apply_async(
-                    kwargs=dict(cc_pair_id=cc_pair.id),
+                    kwargs=dict(cc_pair_id=cc_pair.id, tenant_id=tenant_id),
                 )
 
 
@@ -101,25 +106,33 @@ def check_sync_external_group_permissions_task() -> None:
     name="check_ttl_management_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def check_ttl_management_task() -> None:
+def check_ttl_management_task(tenant_id: str | None) -> None:
     """Runs periodically to check if any ttl tasks should be run and adds them
     to the queue"""
+    token = None
+    if MULTI_TENANT and tenant_id is not None:
+        token = current_tenant_id.set(tenant_id)
+
     settings = load_settings()
     retention_limit_days = settings.maximum_chat_retention_days
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         if should_perform_chat_ttl_check(retention_limit_days, db_session):
             perform_ttl_management_task.apply_async(
-                kwargs=dict(retention_limit_days=retention_limit_days),
+                kwargs=dict(
+                    retention_limit_days=retention_limit_days, tenant_id=tenant_id
+                ),
            )
+    if token is not None:
+        current_tenant_id.reset(token)
 
 
 @celery_app.task(
     name="autogenerate_usage_report_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def autogenerate_usage_report_task() -> None:
+def autogenerate_usage_report_task(tenant_id: str | None) -> None:
     """This generates usage report under the /admin/generate-usage/report endpoint"""
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         create_new_usage_report(
             db_session=db_session,
             user_id=None,
@@ -130,22 +143,48 @@ def autogenerate_usage_report_task() -> None:
 #####
 # Celery Beat (Periodic Tasks) Settings
 #####
-celery_app.conf.beat_schedule = {
-    "sync-external-doc-permissions": {
+tenant_ids = get_all_tenant_ids()
+
+tasks_to_schedule = [
+    {
+        "name": "sync-external-doc-permissions",
         "task": "check_sync_external_doc_permissions_task",
         "schedule": timedelta(seconds=5),  # TODO: optimize this
     },
-    "sync-external-group-permissions": {
+    {
+        "name": "sync-external-group-permissions",
         "task": "check_sync_external_group_permissions_task",
         "schedule": timedelta(seconds=5),  # TODO: optimize this
     },
-    "autogenerate_usage_report": {
+    {
+        "name": "autogenerate_usage_report",
         "task": "autogenerate_usage_report_task",
         "schedule": timedelta(days=30),  # TODO: change this to config flag
     },
-    "check-ttl-management": {
+    {
+        "name": "check-ttl-management",
         "task": "check_ttl_management_task",
         "schedule": timedelta(hours=1),
     },
-    **(celery_app.conf.beat_schedule or {}),
-}
+]
+
+# Build the celery beat schedule dynamically
+beat_schedule = {}
+for tenant_id in tenant_ids:
+    for task in tasks_to_schedule:
+        task_name = f"{task['name']}-{tenant_id}"  # Unique name for each scheduled task
+        beat_schedule[task_name] = {
+            "task": task["task"],
+            "schedule": task["schedule"],
+            "args": (tenant_id,),  # Must pass tenant_id as an argument
+        }
+
+# Include any existing beat schedules
+existing_beat_schedule = celery_app.conf.beat_schedule or {}
+beat_schedule.update(existing_beat_schedule)
+
+# Update the Celery app configuration
+celery_app.conf.beat_schedule = beat_schedule
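For clarity, here is a small hand-written illustration (not taken from the diff) of what the per-tenant beat schedule built above ends up containing; the tenant IDs are made up.

# Illustration only: the loop above produces one entry per (task, tenant).
from datetime import timedelta

tenant_ids = ["tenant_abc", "tenant_xyz"]  # stand-ins for get_all_tenant_ids()

beat_schedule = {
    f"check-ttl-management-{tid}": {
        "task": "check_ttl_management_task",
        "schedule": timedelta(hours=1),
        "args": (tid,),  # each beat entry carries its own tenant_id
    }
    for tid in tenant_ids
}
# -> keys "check-ttl-management-tenant_abc" and "check-ttl-management-tenant_xyz",
#    so Celery Beat enqueues one tenant-scoped task per tenant per interval.
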
@@ -2,9 +2,13 @@ def name_chat_ttl_task(retention_limit_days: int) -> str:
     return f"chat_ttl_{retention_limit_days}_days"
 
 
-def name_sync_external_doc_permissions_task(cc_pair_id: int) -> str:
+def name_sync_external_doc_permissions_task(
+    cc_pair_id: int, tenant_id: str | None = None
+) -> str:
     return f"sync_external_doc_permissions_task__{cc_pair_id}"
 
 
-def name_sync_external_group_permissions_task(cc_pair_id: int) -> str:
+def name_sync_external_group_permissions_task(
+    cc_pair_id: int, tenant_id: str | None = None
+) -> str:
     return f"sync_external_group_permissions_task__{cc_pair_id}"
@@ -4,6 +4,7 @@ from httpx_oauth.clients.openid import OpenID
 from danswer.auth.users import auth_backend
 from danswer.auth.users import fastapi_users
 from danswer.configs.app_configs import AUTH_TYPE
+from danswer.configs.app_configs import MULTI_TENANT
 from danswer.configs.app_configs import OAUTH_CLIENT_ID
 from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
 from danswer.configs.app_configs import USER_AUTH_SECRET
@@ -24,6 +25,7 @@ from ee.danswer.server.enterprise_settings.api import (
     basic_router as enterprise_settings_router,
 )
 from ee.danswer.server.manage.standard_answer import router as standard_answer_router
+from ee.danswer.server.middleware.tenant_tracking import add_tenant_id_middleware
 from ee.danswer.server.query_and_chat.chat_backend import (
     router as chat_router,
 )
@@ -53,6 +55,9 @@ def get_application() -> FastAPI:
 
     application = get_application_base()
 
+    if MULTI_TENANT:
+        add_tenant_id_middleware(application, logger)
+
     if AUTH_TYPE == AuthType.OIDC:
         include_router_with_global_prefix_prepended(
             application,
backend/ee/danswer/server/middleware/tenant_tracking.py (new file)
@@ -0,0 +1,60 @@
+import logging
+from collections.abc import Awaitable
+from collections.abc import Callable
+
+import jwt
+from fastapi import FastAPI
+from fastapi import HTTPException
+from fastapi import Request
+from fastapi import Response
+
+from danswer.configs.app_configs import MULTI_TENANT
+from danswer.configs.app_configs import SECRET_JWT_KEY
+from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
+from danswer.db.engine import is_valid_schema_name
+from shared_configs.configs import current_tenant_id
+
+
+def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None:
+    @app.middleware("http")
+    async def set_tenant_id(
+        request: Request, call_next: Callable[[Request], Awaitable[Response]]
+    ) -> Response:
+        try:
+            logger.info(f"Request route: {request.url.path}")
+
+            if not MULTI_TENANT:
+                tenant_id = POSTGRES_DEFAULT_SCHEMA
+            else:
+                token = request.cookies.get("tenant_details")
+                if token:
+                    try:
+                        payload = jwt.decode(
+                            token, SECRET_JWT_KEY, algorithms=["HS256"]
+                        )
+                        tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
+                        if not is_valid_schema_name(tenant_id):
+                            raise HTTPException(
+                                status_code=400, detail="Invalid tenant ID format"
+                            )
+                    except jwt.InvalidTokenError:
+                        tenant_id = POSTGRES_DEFAULT_SCHEMA
+                    except Exception as e:
+                        logger.error(
+                            f"Unexpected error in set_tenant_id_middleware: {str(e)}"
+                        )
+                        raise HTTPException(
+                            status_code=500, detail="Internal server error"
+                        )
+                else:
+                    tenant_id = POSTGRES_DEFAULT_SCHEMA
+
+            current_tenant_id.set(tenant_id)
+            logger.info(f"Middleware set current_tenant_id to: {tenant_id}")
+
+            response = await call_next(request)
+            return response
+
+        except Exception as e:
+            logger.error(f"Error in tenant ID middleware: {str(e)}")
+            raise
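The issuing side of the "tenant_details" cookie is not shown in this commit; the sketch below is an assumed example of how such a token could be minted so that the middleware above can decode it. The claim name, the HS256 algorithm, and the shared SECRET_JWT_KEY mirror the middleware; everything else is hypothetical.

# Assumed example, not part of the diff: mint the cookie value that
# set_tenant_id() decodes.
import jwt  # PyJWT

SECRET_JWT_KEY = "replace-with-the-server-secret"  # must match the backend config


def make_tenant_cookie_value(tenant_id: str) -> str:
    # The middleware only looks at the "tenant_id" claim.
    return jwt.encode({"tenant_id": tenant_id}, SECRET_JWT_KEY, algorithm="HS256")

# e.g. response.set_cookie("tenant_details", make_tenant_cookie_value("tenant_abc"))
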
@@ -8,8 +8,11 @@ from danswer.db.engine import get_session_with_tenant
 from danswer.setup import setup_danswer
 from danswer.utils.logger import setup_logger
 from ee.danswer.server.tenants.models import CreateTenantRequest
+from ee.danswer.server.tenants.provisioning import add_users_to_tenant
 from ee.danswer.server.tenants.provisioning import ensure_schema_exists
 from ee.danswer.server.tenants.provisioning import run_alembic_migrations
+from ee.danswer.server.tenants.provisioning import user_owns_a_tenant
+from shared_configs.configs import current_tenant_id
 
 logger = setup_logger()
 router = APIRouter(prefix="/tenants")
@@ -19,9 +22,15 @@ router = APIRouter(prefix="/tenants")
 def create_tenant(
     create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep)
 ) -> dict[str, str]:
-    try:
     tenant_id = create_tenant_request.tenant_id
+    email = create_tenant_request.initial_admin_email
+    token = None
+    if user_owns_a_tenant(email):
+        raise HTTPException(
+            status_code=409, detail="User already belongs to an organization"
+        )
+
+    try:
         if not MULTI_TENANT:
             raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
 
@@ -31,10 +40,14 @@ def create_tenant(
             logger.info(f"Schema already exists for tenant {tenant_id}")
 
         run_alembic_migrations(tenant_id)
+        token = current_tenant_id.set(tenant_id)
+        print("getting session", tenant_id)
         with get_session_with_tenant(tenant_id) as db_session:
             setup_danswer(db_session)
 
         logger.info(f"Tenant {tenant_id} created successfully")
+        add_users_to_tenant([email], tenant_id)
+
         return {
             "status": "success",
             "message": f"Tenant {tenant_id} created successfully",
@@ -44,3 +57,6 @@ def create_tenant(
         raise HTTPException(
             status_code=500, detail=f"Failed to create tenant: {str(e)}"
         )
+    finally:
+        if token is not None:
+            current_tenant_id.reset(token)
@@ -8,7 +8,9 @@ from sqlalchemy.schema import CreateSchema
 from alembic import command
 from alembic.config import Config
 from danswer.db.engine import build_connection_string
+from danswer.db.engine import get_session_with_tenant
 from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.models import UserTenantMapping
 from danswer.utils.logger import setup_logger
 
 logger = setup_logger()
@@ -61,3 +63,48 @@ def ensure_schema_exists(tenant_id: str) -> bool:
             db_session.execute(stmt)
             return True
         return False
+
+
+# For now, we're implementing a primitive mapping between users and tenants.
+# This function is only used to determine a user's relationship to a tenant upon creation (implying ownership).
+def user_owns_a_tenant(email: str) -> bool:
+    with get_session_with_tenant("public") as db_session:
+        result = (
+            db_session.query(UserTenantMapping)
+            .filter(UserTenantMapping.email == email)
+            .first()
+        )
+        return result is not None
+
+
+def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
+    with get_session_with_tenant("public") as db_session:
+        try:
+            for email in emails:
+                db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
+        except Exception as e:
+            logger.exception(f"Failed to add users to tenant {tenant_id}: {str(e)}")
+        db_session.commit()
+
+
+def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
+    with get_session_with_tenant("public") as db_session:
+        try:
+            mappings_to_delete = (
+                db_session.query(UserTenantMapping)
+                .filter(
+                    UserTenantMapping.email.in_(emails),
+                    UserTenantMapping.tenant_id == tenant_id,
+                )
+                .all()
+            )
+
+            for mapping in mappings_to_delete:
+                db_session.delete(mapping)
+
+            db_session.commit()
+        except Exception as e:
+            logger.exception(
+                f"Failed to remove users from tenant {tenant_id}: {str(e)}"
+            )
+            db_session.rollback()
@@ -94,6 +94,7 @@ def generate_dummy_chunk(
         ),
         document_sets={document_set for document_set in document_set_names},
         boost=random.randint(-1, 1),
+        tenant_id="public",
     )

@@ -1,3 +1,4 @@
+import contextvars
 import os
 from typing import List
 from urllib.parse import urlparse
@@ -109,3 +110,5 @@ if CORS_ALLOWED_ORIGIN_ENV:
 else:
     # If the environment variable is empty, allow all origins
     CORS_ALLOWED_ORIGIN = ["*"]
+
+current_tenant_id = contextvars.ContextVar("current_tenant_id", default="public")
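As a reminder of the semantics this relies on (the example is not part of the diff): a ContextVar gives each asyncio task or thread its own value, so concurrent requests for different tenants do not overwrite each other.

# Stand-alone illustration of the ContextVar behaviour used above.
import asyncio
import contextvars

current_tenant_id = contextvars.ContextVar("current_tenant_id", default="public")


async def handle_request(tenant: str) -> str:
    token = current_tenant_id.set(tenant)
    try:
        await asyncio.sleep(0)  # yield to other "requests"
        return current_tenant_id.get()  # still this task's tenant
    finally:
        current_tenant_id.reset(token)


async def main() -> None:
    results = await asyncio.gather(*(handle_request(t) for t in ("t1", "t2", "public")))
    assert results == ["t1", "t2", "public"]


asyncio.run(main())
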
@@ -29,6 +29,7 @@ services:
       - SMTP_PORT=${SMTP_PORT:-587}  # For sending verification emails, if unspecified then defaults to '587'
       - SMTP_USER=${SMTP_USER:-}
       - SMTP_PASS=${SMTP_PASS:-}
+      - ENABLE_EMAIL_INVITES=${ENABLE_EMAIL_INVITES:-}  # If enabled, will send users (using SMTP settings) an email to join the workspace
       - EMAIL_FROM=${EMAIL_FROM:-}
       - OAUTH_CLIENT_ID=${OAUTH_CLIENT_ID:-}
       - OAUTH_CLIENT_SECRET=${OAUTH_CLIENT_SECRET:-}
web/src/app/auth/create-account/page.tsx (new file)
@@ -0,0 +1,45 @@
+"use client";
+
+import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
+import { REGISTRATION_URL } from "@/lib/constants";
+import { Button } from "@tremor/react";
+import Link from "next/link";
+import { FiLogIn } from "react-icons/fi";
+
+const Page = () => {
+  return (
+    <AuthFlowContainer>
+      <div className="flex flex-col space-y-6">
+        <h2 className="text-2xl font-bold text-text-900 text-center">
+          Account Not Found
+        </h2>
+        <p className="text-text-700 max-w-md text-center">
+          We couldn't find your account in our records. To access Danswer,
+          you need to either:
+        </p>
+        <ul className="list-disc text-left text-text-600 w-full pl-6 mx-auto">
+          <li>Be invited to an existing Danswer organization</li>
+          <li>Create a new Danswer organization</li>
+        </ul>
+        <div className="flex justify-center">
+          <Link
+            href={`${REGISTRATION_URL}/register`}
+            className="w-full max-w-xs"
+          >
+            <Button size="lg" icon={FiLogIn} color="indigo" className="w-full">
+              Create New Organization
+            </Button>
+          </Link>
+        </div>
+        <p className="text-sm text-text-500 text-center">
+          Have an account with a different email?{" "}
+          <Link href="/auth/login" className="text-indigo-600 hover:underline">
+            Sign in
+          </Link>
+        </p>
+      </div>
+    </AuthFlowContainer>
+  );
+};
+
+export default Page;
@@ -1,21 +1,49 @@
 "use client";
 
+import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
 import { Button } from "@tremor/react";
 import Link from "next/link";
 import { FiLogIn } from "react-icons/fi";
 
 const Page = () => {
   return (
-    <div className="flex flex-col items-center justify-center h-screen">
-      <div className="font-bold">
-        Unable to login, please try again and/or contact an administrator.
+    <AuthFlowContainer>
+      <div className="flex flex-col space-y-6 max-w-md mx-auto">
+        <h2 className="text-2xl font-bold text-text-900 text-center">
+          Authentication Error
+        </h2>
+        <p className="text-text-700 text-center">
+          We encountered an issue while attempting to log you in.
+        </p>
+        <div className="bg-red-50 border border-red-200 rounded-lg p-4 shadow-sm">
+          <h3 className="text-red-800 font-semibold mb-2">Possible Issues:</h3>
+          <ul className="space-y-2">
+            <li className="flex items-center text-red-700">
+              <div className="w-2 h-2 bg-red-500 rounded-full mr-2"></div>
+              Incorrect or expired login credentials
+            </li>
+            <li className="flex items-center text-red-700">
+              <div className="w-2 h-2 bg-red-500 rounded-full mr-2"></div>
+              Temporary authentication system disruption
+            </li>
+            <li className="flex items-center text-red-700">
+              <div className="w-2 h-2 bg-red-500 rounded-full mr-2"></div>
+              Account access restrictions or permissions
+            </li>
+          </ul>
         </div>
-      <Link href="/auth/login" className="w-fit">
-        <Button className="mt-4" size="xs" icon={FiLogIn}>
-          Back to login
+        <Link href="/auth/login" className="w-full">
+          <Button size="lg" icon={FiLogIn} color="indigo" className="w-full">
+            Return to Login Page
           </Button>
         </Link>
+        <p className="text-sm text-text-500 text-center">
+          We recommend trying again. If you continue to experience problems,
+          please reach out to your system administrator for assistance.
+        </p>
       </div>
-    </div>
+    </AuthFlowContainer>
   );
 };

@@ -6,11 +6,15 @@ import { SettingsContext } from "@/components/settings/SettingsProvider";
 export const LoginText = () => {
   const settings = useContext(SettingsContext);
 
-  if (!settings) {
-    throw new Error("SettingsContext is not available");
-  }
+  // if (!settings) {
+  //   throw new Error("SettingsContext is not available");
+  // }
 
   return (
-    <>Log In to {settings?.enterpriseSettings?.application_name || "Danswer"}</>
+    <>
+      Log In to{" "}
+      {(settings && settings?.enterpriseSettings?.application_name) ||
+        "Danswer"}
+    </>
   );
 };
@@ -14,6 +14,7 @@ import Link from "next/link";
 import { Logo } from "@/components/Logo";
 import { LoginText } from "./LoginText";
 import { getSecondsUntilExpiration } from "@/lib/time";
+import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
 
 const Page = async ({
   searchParams,
@@ -51,7 +52,6 @@ const Page = async ({
     if (authTypeMetadata?.requiresVerification && !currentUser.is_verified) {
       return redirect("/auth/waiting-on-verification");
     }
-
     return redirect("/");
   }
 
@@ -70,16 +70,15 @@ const Page = async ({
   }
 
   return (
-    <main>
+    <AuthFlowContainer>
       <div className="absolute top-10x w-full">
         <HealthCheckBanner />
       </div>
-      <div className="min-h-screen flex items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
       <div>
-        <Logo height={64} width={64} className="mx-auto w-fit" />
         {authUrl && authTypeMetadata && (
           <>
-            <h2 className="text-center text-xl text-strong font-bold mt-6">
+            <h2 className="text-center text-xl text-strong font-bold">
               <LoginText />
             </h2>
 
@@ -108,8 +107,7 @@ const Page = async ({
           </Card>
         )}
       </div>
-      </div>
-    </main>
+    </AuthFlowContainer>
   );
 };

@@ -11,6 +11,12 @@ export const GET = async (request: NextRequest) => {
   const response = await fetch(url.toString());
   const setCookieHeader = response.headers.get("set-cookie");
 
+  if (response.status === 401) {
+    return NextResponse.redirect(
+      new URL("/auth/create-account", getDomain(request))
+    );
+  }
+
   if (!setCookieHeader) {
     return NextResponse.redirect(new URL("/auth/error", getDomain(request)));
   }
@@ -10,6 +10,7 @@ import { EmailPasswordForm } from "../login/EmailPasswordForm";
 import { Card, Title, Text } from "@tremor/react";
 import Link from "next/link";
 import { Logo } from "@/components/Logo";
+import { CLOUD_ENABLED } from "@/lib/constants";
 
 const Page = async () => {
   // catch cases where the backend is completely unreachable here
@@ -25,6 +26,9 @@ const Page = async () => {
   } catch (e) {
     console.log(`Some fetch failed for the login page - ${e}`);
   }
+  if (CLOUD_ENABLED) {
+    return redirect("/auth/login");
+  }
 
   // simply take the user to the home page if Auth is disabled
   if (authTypeMetadata?.authType === "disabled") {
@@ -19,6 +19,8 @@ import { HeaderTitle } from "@/components/header/HeaderTitle";
 import { Logo } from "@/components/Logo";
 import { UserProvider } from "@/components/user/UserProvider";
 import { ProviderContextProvider } from "@/components/chat_search/ProviderContext";
+import { redirect } from "next/navigation";
+import { headers } from "next/headers";
 
 const inter = Inter({
   subsets: ["latin"],
@@ -56,8 +58,6 @@ export default async function RootLayout({
   const combinedSettings = await fetchSettingsSS();
 
   if (!combinedSettings) {
-    // Just display a simple full page error if fetching fails.
-
     return (
       <html lang="en" className={`${inter.variable} font-sans`}>
         <Head>
web/src/components/auth/AuthFlowContainer.tsx (new file)
@@ -0,0 +1,16 @@
+import { Logo } from "../Logo";
+
+export default function AuthFlowContainer({
+  children,
+}: {
+  children: React.ReactNode;
+}) {
+  return (
+    <div className="flex flex-col items-center justify-center min-h-screen bg-background">
+      <div className="w-full max-w-md p-8 gap-y-4 bg-white flex items-center flex-col rounded-xl shadow-lg border border-bacgkround-100">
+        <Logo width={70} height={70} />
+        {children}
+      </div>
+    </div>
+  );
+}
@@ -40,7 +40,7 @@ export async function fetchSettingsSS(): Promise<CombinedSettings | null> {
 
   let settings: Settings;
   if (!results[0].ok) {
-    if (results[0].status === 403) {
+    if (results[0].status === 403 || results[0].status === 401) {
       settings = {
         gpu_enabled: false,
         chat_page_enabled: true,
@@ -62,7 +62,7 @@ export async function fetchSettingsSS(): Promise<CombinedSettings | null> {
   let enterpriseSettings: EnterpriseSettings | null = null;
   if (tasks.length > 1) {
     if (!results[1].ok) {
-      if (results[1].status !== 403) {
+      if (results[1].status !== 403 && results[1].status !== 401) {
         throw new Error(
           `fetchEnterpriseSettingsSS failed: status=${results[1].status} body=${await results[1].text()}`
         );
@@ -55,3 +55,7 @@ export const CUSTOM_ANALYTICS_ENABLED = process.env.CUSTOM_ANALYTICS_SECRET_KEY
 
 export const DISABLE_LLM_DOC_RELEVANCE =
   process.env.DISABLE_LLM_DOC_RELEVANCE?.toLowerCase() === "true";
+
+export const CLOUD_ENABLED = process.env.NEXT_PUBLIC_CLOUD_ENABLED;
+export const REGISTRATION_URL =
+  process.env.INTERNAL_URL || "http://127.0.0.1:3001";