Allow basic seeding of Danswer via env variable

This commit is contained in:
Weves 2024-06-10 15:05:20 -07:00 committed by Chris Weaver
parent 7278d45552
commit 7cc51376f2
7 changed files with 110 additions and 3 deletions

View File

@ -46,6 +46,7 @@ from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.engine import get_session
@ -54,7 +55,9 @@ from danswer.db.models import User
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
fetch_versioned_implementation,
)
logger = setup_logger()
@ -148,7 +151,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
verify_email_domain(user_create.email)
if hasattr(user_create, "role"):
user_count = await get_user_count()
if user_count == 0:
if user_count == 0 or user_create.email in get_default_admin_user_emails():
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC

View File

@ -1,4 +1,5 @@
from collections.abc import AsyncGenerator
from collections.abc import Callable
from typing import Any
from typing import Dict
@ -16,6 +17,20 @@ from danswer.db.engine import get_sqlalchemy_async_engine
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
def get_default_admin_user_emails() -> list[str]:
"""Returns a list of emails who should default to Admin role.
Only used in the EE version. For MIT, just return empty list."""
get_default_admin_user_emails_fn: Callable[
[], list[str]
] = fetch_versioned_implementation_with_fallback(
"danswer.auth.users", "get_default_admin_user_emails_", lambda: []
)
return get_default_admin_user_emails_fn()
async def get_user_count() -> int:
@ -32,7 +47,7 @@ async def get_user_count() -> int:
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase):
async def create(self, create_dict: Dict[str, Any]) -> UP:
user_count = await get_user_count()
if user_count == 0:
if user_count == 0 or create_dict["email"] in get_default_admin_user_emails():
create_dict["role"] = UserRole.ADMIN
else:
create_dict["role"] = UserRole.BASIC

View File

@ -1,6 +1,7 @@
import functools
import importlib
from typing import Any
from typing import TypeVar
from danswer.utils.logger import setup_logger
@ -36,3 +37,21 @@ def fetch_versioned_implementation(module: str, attribute: str) -> Any:
return getattr(importlib.import_module(module), attribute)
raise
T = TypeVar("T")
def fetch_versioned_implementation_with_fallback(
module: str, attribute: str, fallback: T
) -> T:
try:
return fetch_versioned_implementation(module, attribute)
except Exception as e:
logger.warning(
"Failed to fetch versioned implementation for %s.%s: %s",
module,
attribute,
e,
)
return fallback

View File

@ -11,6 +11,7 @@ from danswer.utils.logger import setup_logger
from ee.danswer.auth.api_key import get_hashed_api_key_from_request
from ee.danswer.db.api_key import fetch_user_for_api_key
from ee.danswer.db.saml import get_saml_account
from ee.danswer.server.seeding import get_seed_config
from ee.danswer.utils.secrets import extract_hashed_cookie
logger = setup_logger()
@ -55,3 +56,10 @@ def api_key_dep(request: Request, db_session: Session = Depends(get_session)) ->
raise HTTPException(status_code=401, detail="Invalid API key")
return user
def get_default_admin_user_emails_() -> list[str]:
seed_config = get_seed_config()
if seed_config and seed_config.admin_user_emails:
return seed_config.admin_user_emails
return []

View File

@ -34,6 +34,7 @@ from ee.danswer.server.query_and_chat.query_backend import (
)
from ee.danswer.server.query_history.api import router as query_history_router
from ee.danswer.server.saml import router as saml_router
from ee.danswer.server.seeding import seed_db
from ee.danswer.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)
@ -100,6 +101,10 @@ def get_ee_application() -> FastAPI:
# Ensure all routes have auth enabled or are explicitly marked as public
check_ee_router_auth(application)
# seed the Danswer environment with LLMs, Assistants, etc. based on an optional
# environment variable. Used to automate deployment for multiple environments.
seed_db()
return application

View File

@ -41,6 +41,7 @@ def handle_search_request(
query = search_request.message
logger.info(f"Received document search query: {query}")
llm = get_default_llm()
search_pipeline = SearchPipeline(
search_request=SearchRequest(
query=query,
@ -57,6 +58,7 @@ def handle_search_request(
full_doc=search_request.full_doc,
),
user=user,
llm=llm,
db_session=db_session,
bypass_acl=False,
)

View File

@ -0,0 +1,55 @@
import os
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.llm import upsert_llm_provider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.utils.logger import setup_logger
logger = setup_logger()
_SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION"
class SeedConfiguration(BaseModel):
llms: list[LLMProviderUpsertRequest] | None = None
admin_user_emails: list[str] | None = None
def _parse_env() -> SeedConfiguration | None:
seed_config_str = os.getenv(_SEED_CONFIG_ENV_VAR_NAME)
if seed_config_str is None:
return None
seed_config = SeedConfiguration.parse_raw(seed_config_str)
return seed_config
def _seed_llms(
db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
) -> None:
# don't seed LLMs if we've already done this
existing_llms = fetch_existing_llm_providers(db_session)
if existing_llms:
return
logger.info("Seeding LLMs")
for llm_upsert_request in llm_upsert_requests:
upsert_llm_provider(db_session, llm_upsert_request)
def get_seed_config() -> SeedConfiguration | None:
return _parse_env()
def seed_db() -> None:
seed_config = _parse_env()
if seed_config is None:
return
with get_session_context_manager() as db_session:
if seed_config.llms is not None:
_seed_llms(db_session, seed_config.llms)