Allow basic seeding of Danswer via env variable

This commit is contained in:
Weves
2024-06-10 15:05:20 -07:00
committed by Chris Weaver
parent 7278d45552
commit 7cc51376f2
7 changed files with 110 additions and 3 deletions

View File

@@ -11,6 +11,7 @@ from danswer.utils.logger import setup_logger
from ee.danswer.auth.api_key import get_hashed_api_key_from_request
from ee.danswer.db.api_key import fetch_user_for_api_key
from ee.danswer.db.saml import get_saml_account
from ee.danswer.server.seeding import get_seed_config
from ee.danswer.utils.secrets import extract_hashed_cookie
logger = setup_logger()
@@ -55,3 +56,10 @@ def api_key_dep(request: Request, db_session: Session = Depends(get_session)) ->
raise HTTPException(status_code=401, detail="Invalid API key")
return user
def get_default_admin_user_emails_() -> list[str]:
    """Return the admin user emails listed in the seed configuration.

    Returns:
        The ``admin_user_emails`` list from the seed configuration, or an
        empty list when there is no seed configuration or it lists no emails.
    """
    config = get_seed_config()
    if not config:
        # No seed configuration available, so there are no seeded admins.
        return []

    emails = config.admin_user_emails
    return emails if emails else []

View File

@@ -34,6 +34,7 @@ from ee.danswer.server.query_and_chat.query_backend import (
)
from ee.danswer.server.query_history.api import router as query_history_router
from ee.danswer.server.saml import router as saml_router
from ee.danswer.server.seeding import seed_db
from ee.danswer.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)
@@ -100,6 +101,10 @@ def get_ee_application() -> FastAPI:
# Ensure all routes have auth enabled or are explicitly marked as public
check_ee_router_auth(application)
# seed the Danswer environment with LLMs, Assistants, etc. based on an optional
# environment variable. Used to automate deployment for multiple environments.
seed_db()
return application

View File

@@ -41,6 +41,7 @@ def handle_search_request(
query = search_request.message
logger.info(f"Received document search query: {query}")
llm = get_default_llm()
search_pipeline = SearchPipeline(
search_request=SearchRequest(
query=query,
@@ -57,6 +58,7 @@ def handle_search_request(
full_doc=search_request.full_doc,
),
user=user,
llm=llm,
db_session=db_session,
bypass_acl=False,
)

View File

@@ -0,0 +1,55 @@
import os
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.llm import upsert_llm_provider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.utils.logger import setup_logger
# Module-level logger created via danswer.utils.logger.setup_logger.
logger = setup_logger()
# Name of the environment variable read by _parse_env; its value is parsed
# into a SeedConfiguration.
_SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION"
class SeedConfiguration(BaseModel):
    """Seed settings parsed from the seeding environment variable.

    Both fields are optional and default to ``None`` when omitted.
    """

    # LLM provider upsert requests; seed_db hands these to _seed_llms.
    llms: list[LLMProviderUpsertRequest] | None = None
    # Admin user email addresses (presumably consumed outside this module).
    admin_user_emails: list[str] | None = None
def _parse_env() -> SeedConfiguration | None:
    """Load the seed configuration from the environment, if one is provided.

    Reads the variable named by ``_SEED_CONFIG_ENV_VAR_NAME`` and parses its
    content with ``SeedConfiguration.parse_raw``.

    Returns:
        The parsed ``SeedConfiguration``, or ``None`` when the variable is
        unset or set to an empty string.

    Raises:
        Whatever ``SeedConfiguration.parse_raw`` raises for non-empty content
        it cannot parse; nothing is caught here, so the error propagates.
    """
    seed_config_str = os.getenv(_SEED_CONFIG_ENV_VAR_NAME)
    # A variable that is declared but left empty (e.g. `ENV_SEED_CONFIGURATION=`)
    # means "no seeding", not "invalid configuration"; do not attempt to parse
    # it.
    if not seed_config_str:
        return None

    return SeedConfiguration.parse_raw(seed_config_str)
def _seed_llms(
    db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
) -> None:
    """Upsert the given LLM providers unless some providers already exist.

    Args:
        db_session: Session passed to the provider lookup and to each upsert.
        llm_upsert_requests: Provider upsert requests to apply, in order.
    """
    # Skip seeding entirely once any LLM provider is already present.
    if fetch_existing_llm_providers(db_session):
        return

    logger.info("Seeding LLMs")
    for request in llm_upsert_requests:
        upsert_llm_provider(db_session, request)
def get_seed_config() -> SeedConfiguration | None:
    """Return the seed configuration produced by ``_parse_env``.

    Returns:
        The result of ``_parse_env()``, which is ``None`` when no seed
        configuration is available.
    """
    config = _parse_env()
    return config
def seed_db() -> None:
    """Seed the database from the environment-provided seed configuration.

    Does nothing when ``_parse_env`` yields no configuration. Otherwise opens
    a database session and, when the configuration has an ``llms`` entry,
    passes it to ``_seed_llms``.
    """
    config = _parse_env()
    if config is None:
        return

    with get_session_context_manager() as db_session:
        llm_requests = config.llms
        if llm_requests is not None:
            _seed_llms(db_session, llm_requests)