From 7cc51376f23e1bbd108469483ff74669e6550b8b Mon Sep 17 00:00:00 2001
From: Weves
Date: Mon, 10 Jun 2024 15:05:20 -0700
Subject: [PATCH] Allow basic seeding of Danswer via env variable

---
 backend/danswer/auth/users.py                 |  7 ++-
 backend/danswer/db/auth.py                    | 17 +++++-
 .../danswer/utils/variable_functionality.py   | 19 +++++++
 backend/ee/danswer/auth/users.py              |  8 +++
 backend/ee/danswer/main.py                    |  5 ++
 .../server/query_and_chat/query_backend.py    |  2 +
 backend/ee/danswer/server/seeding.py          | 55 +++++++++++++++++++
 7 files changed, 110 insertions(+), 3 deletions(-)
 create mode 100644 backend/ee/danswer/server/seeding.py

diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py
index dd7bc9f77..06c0841d8 100644
--- a/backend/danswer/auth/users.py
+++ b/backend/danswer/auth/users.py
@@ -46,6 +46,7 @@ from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
 from danswer.configs.constants import DANSWER_API_KEY_PREFIX
 from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
 from danswer.db.auth import get_access_token_db
+from danswer.db.auth import get_default_admin_user_emails
 from danswer.db.auth import get_user_count
 from danswer.db.auth import get_user_db
 from danswer.db.engine import get_session
@@ -54,7 +55,9 @@ from danswer.db.models import User
 from danswer.utils.logger import setup_logger
 from danswer.utils.telemetry import optional_telemetry
 from danswer.utils.telemetry import RecordType
-from danswer.utils.variable_functionality import fetch_versioned_implementation
+from danswer.utils.variable_functionality import (
+    fetch_versioned_implementation,
+)
 
 logger = setup_logger()
 
@@ -148,7 +151,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
         verify_email_domain(user_create.email)
         if hasattr(user_create, "role"):
             user_count = await get_user_count()
-            if user_count == 0:
+            if user_count == 0 or user_create.email in get_default_admin_user_emails():
                 user_create.role = UserRole.ADMIN
             else:
                 user_create.role = UserRole.BASIC
diff --git a/backend/danswer/db/auth.py b/backend/danswer/db/auth.py
index 6d726c2f9..161fdc8f1 100644
--- a/backend/danswer/db/auth.py
+++ b/backend/danswer/db/auth.py
@@ -1,4 +1,5 @@
 from collections.abc import AsyncGenerator
+from collections.abc import Callable
 from typing import Any
 from typing import Dict
 
@@ -16,6 +17,20 @@ from danswer.db.engine import get_sqlalchemy_async_engine
 from danswer.db.models import AccessToken
 from danswer.db.models import OAuthAccount
 from danswer.db.models import User
+from danswer.utils.variable_functionality import (
+    fetch_versioned_implementation_with_fallback,
+)
+
+
+def get_default_admin_user_emails() -> list[str]:
+    """Returns a list of emails who should default to Admin role.
+    Only used in the EE version. For MIT, just return empty list."""
+    get_default_admin_user_emails_fn: Callable[
+        [], list[str]
+    ] = fetch_versioned_implementation_with_fallback(
+        "danswer.auth.users", "get_default_admin_user_emails_", lambda: []
+    )
+    return get_default_admin_user_emails_fn()
 
 
 async def get_user_count() -> int:
@@ -32,7 +47,7 @@ async def get_user_count() -> int:
 class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase):
     async def create(self, create_dict: Dict[str, Any]) -> UP:
         user_count = await get_user_count()
-        if user_count == 0:
+        if user_count == 0 or create_dict["email"] in get_default_admin_user_emails():
             create_dict["role"] = UserRole.ADMIN
         else:
             create_dict["role"] = UserRole.BASIC
diff --git a/backend/danswer/utils/variable_functionality.py b/backend/danswer/utils/variable_functionality.py
index 61414effd..d813c10b4 100644
--- a/backend/danswer/utils/variable_functionality.py
+++ b/backend/danswer/utils/variable_functionality.py
@@ -1,6 +1,7 @@
 import functools
 import importlib
 from typing import Any
+from typing import TypeVar
 
 from danswer.utils.logger import setup_logger
 
@@ -36,3 +37,21 @@ def fetch_versioned_implementation(module: str, attribute: str) -> Any:
             return getattr(importlib.import_module(module), attribute)
 
         raise
+
+
+T = TypeVar("T")
+
+
+def fetch_versioned_implementation_with_fallback(
+    module: str, attribute: str, fallback: T
+) -> T:
+    try:
+        return fetch_versioned_implementation(module, attribute)
+    except Exception as e:
+        logger.warning(
+            "Failed to fetch versioned implementation for %s.%s: %s",
+            module,
+            attribute,
+            e,
+        )
+        return fallback
diff --git a/backend/ee/danswer/auth/users.py b/backend/ee/danswer/auth/users.py
index 899cb6353..f5f5dbd58 100644
--- a/backend/ee/danswer/auth/users.py
+++ b/backend/ee/danswer/auth/users.py
@@ -11,6 +11,7 @@ from danswer.utils.logger import setup_logger
 from ee.danswer.auth.api_key import get_hashed_api_key_from_request
 from ee.danswer.db.api_key import fetch_user_for_api_key
 from ee.danswer.db.saml import get_saml_account
+from ee.danswer.server.seeding import get_seed_config
 from ee.danswer.utils.secrets import extract_hashed_cookie
 
 logger = setup_logger()
@@ -55,3 +56,10 @@ def api_key_dep(request: Request, db_session: Session = Depends(get_session)) ->
         raise HTTPException(status_code=401, detail="Invalid API key")
 
     return user
+
+
+def get_default_admin_user_emails_() -> list[str]:
+    seed_config = get_seed_config()
+    if seed_config and seed_config.admin_user_emails:
+        return seed_config.admin_user_emails
+    return []
diff --git a/backend/ee/danswer/main.py b/backend/ee/danswer/main.py
index a982657f4..30fcfb909 100644
--- a/backend/ee/danswer/main.py
+++ b/backend/ee/danswer/main.py
@@ -34,6 +34,7 @@ from ee.danswer.server.query_and_chat.query_backend import (
 )
 from ee.danswer.server.query_history.api import router as query_history_router
 from ee.danswer.server.saml import router as saml_router
+from ee.danswer.server.seeding import seed_db
 from ee.danswer.server.token_rate_limits.api import (
     router as token_rate_limit_settings_router,
 )
@@ -100,6 +101,10 @@ def get_ee_application() -> FastAPI:
     # Ensure all routes have auth enabled or are explicitly marked as public
     check_ee_router_auth(application)
 
+    # seed the Danswer environment with LLMs, Assistants, etc. based on an optional
+    # environment variable. Used to automate deployment for multiple environments.
+    seed_db()
+
     return application
 
 
diff --git a/backend/ee/danswer/server/query_and_chat/query_backend.py b/backend/ee/danswer/server/query_and_chat/query_backend.py
index 8828391ad..9793479fc 100644
--- a/backend/ee/danswer/server/query_and_chat/query_backend.py
+++ b/backend/ee/danswer/server/query_and_chat/query_backend.py
@@ -41,6 +41,7 @@ def handle_search_request(
     query = search_request.message
     logger.info(f"Received document search query: {query}")
 
+    llm = get_default_llm()
     search_pipeline = SearchPipeline(
         search_request=SearchRequest(
             query=query,
@@ -57,6 +58,7 @@ def handle_search_request(
             full_doc=search_request.full_doc,
         ),
         user=user,
+        llm=llm,
         db_session=db_session,
         bypass_acl=False,
     )
diff --git a/backend/ee/danswer/server/seeding.py b/backend/ee/danswer/server/seeding.py
new file mode 100644
index 000000000..225ffdc90
--- /dev/null
+++ b/backend/ee/danswer/server/seeding.py
@@ -0,0 +1,55 @@
+import os
+
+from pydantic import BaseModel
+from sqlalchemy.orm import Session
+
+from danswer.db.engine import get_session_context_manager
+from danswer.db.llm import fetch_existing_llm_providers
+from danswer.db.llm import upsert_llm_provider
+from danswer.server.manage.llm.models import LLMProviderUpsertRequest
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+_SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION"
+
+
+class SeedConfiguration(BaseModel):
+    llms: list[LLMProviderUpsertRequest] | None = None
+    admin_user_emails: list[str] | None = None
+
+
+def _parse_env() -> SeedConfiguration | None:
+    seed_config_str = os.getenv(_SEED_CONFIG_ENV_VAR_NAME)
+    if seed_config_str is None:
+        return None
+
+    seed_config = SeedConfiguration.parse_raw(seed_config_str)
+    return seed_config
+
+
+def _seed_llms(
+    db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
+) -> None:
+    # don't seed LLMs if we've already done this
+    existing_llms = fetch_existing_llm_providers(db_session)
+    if existing_llms:
+        return
+
+    logger.info("Seeding LLMs")
+    for llm_upsert_request in llm_upsert_requests:
+        upsert_llm_provider(db_session, llm_upsert_request)
+
+
+def get_seed_config() -> SeedConfiguration | None:
+    return _parse_env()
+
+
+def seed_db() -> None:
+    seed_config = _parse_env()
+    if seed_config is None:
+        return
+
+    with get_session_context_manager() as db_session:
+        if seed_config.llms is not None:
+            _seed_llms(db_session, seed_config.llms)