mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-03 09:28:25 +02:00
Allow basic seeding of Danswer via env variable
This commit is contained in:
parent
7278d45552
commit
7cc51376f2
@ -46,6 +46,7 @@ from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from danswer.db.auth import get_access_token_db
|
||||
from danswer.db.auth import get_default_admin_user_emails
|
||||
from danswer.db.auth import get_user_count
|
||||
from danswer.db.auth import get_user_db
|
||||
from danswer.db.engine import get_session
|
||||
@ -54,7 +55,9 @@ from danswer.db.models import User
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation,
|
||||
)
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@ -148,7 +151,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
verify_email_domain(user_create.email)
|
||||
if hasattr(user_create, "role"):
|
||||
user_count = await get_user_count()
|
||||
if user_count == 0:
|
||||
if user_count == 0 or user_create.email in get_default_admin_user_emails():
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
|
@ -1,4 +1,5 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
@ -16,6 +17,20 @@ from danswer.db.engine import get_sqlalchemy_async_engine
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import OAuthAccount
|
||||
from danswer.db.models import User
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
|
||||
|
||||
def get_default_admin_user_emails() -> list[str]:
|
||||
"""Returns a list of emails who should default to Admin role.
|
||||
Only used in the EE version. For MIT, just return empty list."""
|
||||
get_default_admin_user_emails_fn: Callable[
|
||||
[], list[str]
|
||||
] = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.auth.users", "get_default_admin_user_emails_", lambda: []
|
||||
)
|
||||
return get_default_admin_user_emails_fn()
|
||||
|
||||
|
||||
async def get_user_count() -> int:
|
||||
@ -32,7 +47,7 @@ async def get_user_count() -> int:
|
||||
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase):
|
||||
async def create(self, create_dict: Dict[str, Any]) -> UP:
|
||||
user_count = await get_user_count()
|
||||
if user_count == 0:
|
||||
if user_count == 0 or create_dict["email"] in get_default_admin_user_emails():
|
||||
create_dict["role"] = UserRole.ADMIN
|
||||
else:
|
||||
create_dict["role"] = UserRole.BASIC
|
||||
|
@ -1,6 +1,7 @@
|
||||
import functools
|
||||
import importlib
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@ -36,3 +37,21 @@ def fetch_versioned_implementation(module: str, attribute: str) -> Any:
|
||||
return getattr(importlib.import_module(module), attribute)
|
||||
|
||||
raise
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def fetch_versioned_implementation_with_fallback(
|
||||
module: str, attribute: str, fallback: T
|
||||
) -> T:
|
||||
try:
|
||||
return fetch_versioned_implementation(module, attribute)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to fetch versioned implementation for %s.%s: %s",
|
||||
module,
|
||||
attribute,
|
||||
e,
|
||||
)
|
||||
return fallback
|
||||
|
@ -11,6 +11,7 @@ from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.auth.api_key import get_hashed_api_key_from_request
|
||||
from ee.danswer.db.api_key import fetch_user_for_api_key
|
||||
from ee.danswer.db.saml import get_saml_account
|
||||
from ee.danswer.server.seeding import get_seed_config
|
||||
from ee.danswer.utils.secrets import extract_hashed_cookie
|
||||
|
||||
logger = setup_logger()
|
||||
@ -55,3 +56,10 @@ def api_key_dep(request: Request, db_session: Session = Depends(get_session)) ->
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def get_default_admin_user_emails_() -> list[str]:
|
||||
seed_config = get_seed_config()
|
||||
if seed_config and seed_config.admin_user_emails:
|
||||
return seed_config.admin_user_emails
|
||||
return []
|
||||
|
@ -34,6 +34,7 @@ from ee.danswer.server.query_and_chat.query_backend import (
|
||||
)
|
||||
from ee.danswer.server.query_history.api import router as query_history_router
|
||||
from ee.danswer.server.saml import router as saml_router
|
||||
from ee.danswer.server.seeding import seed_db
|
||||
from ee.danswer.server.token_rate_limits.api import (
|
||||
router as token_rate_limit_settings_router,
|
||||
)
|
||||
@ -100,6 +101,10 @@ def get_ee_application() -> FastAPI:
|
||||
# Ensure all routes have auth enabled or are explicitly marked as public
|
||||
check_ee_router_auth(application)
|
||||
|
||||
# seed the Danswer environment with LLMs, Assistants, etc. based on an optional
|
||||
# environment variable. Used to automate deployment for multiple environments.
|
||||
seed_db()
|
||||
|
||||
return application
|
||||
|
||||
|
||||
|
@ -41,6 +41,7 @@ def handle_search_request(
|
||||
query = search_request.message
|
||||
logger.info(f"Received document search query: {query}")
|
||||
|
||||
llm = get_default_llm()
|
||||
search_pipeline = SearchPipeline(
|
||||
search_request=SearchRequest(
|
||||
query=query,
|
||||
@ -57,6 +58,7 @@ def handle_search_request(
|
||||
full_doc=search_request.full_doc,
|
||||
),
|
||||
user=user,
|
||||
llm=llm,
|
||||
db_session=db_session,
|
||||
bypass_acl=False,
|
||||
)
|
||||
|
55
backend/ee/danswer/server/seeding.py
Normal file
55
backend/ee/danswer/server/seeding.py
Normal file
@ -0,0 +1,55 @@
|
||||
import os
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.llm import upsert_llm_provider
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION"
|
||||
|
||||
|
||||
class SeedConfiguration(BaseModel):
|
||||
llms: list[LLMProviderUpsertRequest] | None = None
|
||||
admin_user_emails: list[str] | None = None
|
||||
|
||||
|
||||
def _parse_env() -> SeedConfiguration | None:
|
||||
seed_config_str = os.getenv(_SEED_CONFIG_ENV_VAR_NAME)
|
||||
if seed_config_str is None:
|
||||
return None
|
||||
|
||||
seed_config = SeedConfiguration.parse_raw(seed_config_str)
|
||||
return seed_config
|
||||
|
||||
|
||||
def _seed_llms(
|
||||
db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
|
||||
) -> None:
|
||||
# don't seed LLMs if we've already done this
|
||||
existing_llms = fetch_existing_llm_providers(db_session)
|
||||
if existing_llms:
|
||||
return
|
||||
|
||||
logger.info("Seeding LLMs")
|
||||
for llm_upsert_request in llm_upsert_requests:
|
||||
upsert_llm_provider(db_session, llm_upsert_request)
|
||||
|
||||
|
||||
def get_seed_config() -> SeedConfiguration | None:
|
||||
return _parse_env()
|
||||
|
||||
|
||||
def seed_db() -> None:
|
||||
seed_config = _parse_env()
|
||||
if seed_config is None:
|
||||
return
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
if seed_config.llms is not None:
|
||||
_seed_llms(db_session, seed_config.llms)
|
Loading…
x
Reference in New Issue
Block a user