mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-13 13:20:15 +02:00
254 lines
9.4 KiB
Python
254 lines
9.4 KiB
Python
import json
|
|
import os
|
|
from copy import deepcopy
|
|
from typing import List
|
|
from typing import Optional
|
|
|
|
from pydantic import BaseModel
|
|
from sqlalchemy.orm import Session
|
|
|
|
from ee.onyx.db.standard_answer import (
|
|
create_initial_default_standard_answer_category,
|
|
)
|
|
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
|
|
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
|
|
from ee.onyx.server.enterprise_settings.models import NavigationItem
|
|
from ee.onyx.server.enterprise_settings.store import store_analytics_script
|
|
from ee.onyx.server.enterprise_settings.store import (
|
|
store_settings as store_ee_settings,
|
|
)
|
|
from ee.onyx.server.enterprise_settings.store import upload_logo
|
|
from onyx.context.search.enums import RecencyBiasSetting
|
|
from onyx.db.engine import get_session_context_manager
|
|
from onyx.db.llm import update_default_provider
|
|
from onyx.db.llm import upsert_llm_provider
|
|
from onyx.db.models import Tool
|
|
from onyx.db.persona import upsert_persona
|
|
from onyx.server.features.persona.models import CreatePersonaRequest
|
|
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
|
from onyx.server.settings.models import Settings
|
|
from onyx.server.settings.store import store_settings as store_base_settings
|
|
from onyx.utils.logger import setup_logger
|
|
|
|
|
|
class CustomToolSeed(BaseModel):
    # Seed description of a single custom tool. `_seed_custom_tools` copies
    # every field except `definition_path` onto a new `Tool` row.
    # NOTE: documented with `#` comments instead of a docstring so that the
    # schema description pydantic derives from the class stays unchanged.
    name: str
    description: str
    # Path to a file holding the tool definition as JSON; the parsed content
    # is stored as the tool's `openapi_schema`.
    definition_path: str
    custom_headers: list[dict] | None = None
    display_name: str | None = None
    in_code_tool_id: str | None = None
    user_id: str | None = None
|
|
|
|
|
|
# Module-level logger shared by the seeding helpers in this file.
logger = setup_logger()

# Name of the environment variable read by `_parse_env`; its value is parsed
# as the JSON form of `SeedConfiguration`.
_SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION"
|
|
|
|
|
|
class NavigationItemSeed(BaseModel):
    # Seed form of a custom navigation item. `_seed_enterprise_settings` turns
    # each one into a `NavigationItem`, passing `link` and `title` through and
    # loading the logo from `svg_path`.
    link: str
    title: str
    # Path to the SVG file whose (stripped) text becomes the item's `svg_logo`.
    # NOTE: SVG at this path must not have a width / height specified
    svg_path: str
|
|
|
|
|
|
class SeedConfiguration(BaseModel):
    # Top-level seeding configuration. `_parse_env` validates it from the JSON
    # stored in the `_SEED_CONFIG_ENV_VAR_NAME` environment variable, and
    # `seed_db` applies it section by section.

    # LLM providers to upsert; the first entry is made the default provider.
    llms: list[LLMProviderUpsertRequest] | None = None
    # NOTE(review): not read by any seeding function visible in this file.
    admin_user_emails: list[str] | None = None
    # Passed to `upload_logo` as `file` when set.
    seeded_logo_path: str | None = None
    # Personas to upsert; each must carry `prompt_ids` or seeding raises.
    personas: list[CreatePersonaRequest] | None = None
    # Stored through the base settings store when set.
    settings: Settings | None = None
    # Stored through the enterprise settings store (see also the overrides below).
    enterprise_settings: EnterpriseSettings | None = None

    # allows for specifying custom navigation items that have your own custom SVG logos
    # When set, these replace `enterprise_settings.custom_nav_items` entirely.
    nav_item_overrides: list[NavigationItemSeed] | None = None

    # Use existing `CUSTOM_ANALYTICS_SECRET_KEY` for reference
    # The script is only seeded when that environment variable is also set.
    analytics_script_path: str | None = None
    # Custom tools to create from on-disk JSON definitions.
    custom_tools: List[CustomToolSeed] | None = None
|
|
|
|
|
|
def _parse_env() -> SeedConfiguration | None:
    """Load the seed configuration from the environment, if one is present.

    Returns:
        None when the `_SEED_CONFIG_ENV_VAR_NAME` variable is unset or empty;
        otherwise its value validated as JSON into a `SeedConfiguration`.
    """
    raw_config = os.getenv(_SEED_CONFIG_ENV_VAR_NAME)
    return SeedConfiguration.model_validate_json(raw_config) if raw_config else None
|
|
|
|
|
|
def _seed_custom_tools(db_session: Session, tools: List[CustomToolSeed]) -> None:
    """Create a `Tool` row for every custom tool seed whose definition loads.

    Seeding is best-effort per tool: a missing definition file, an empty or
    invalid JSON definition, or any other error is logged and that tool is
    skipped, while the remaining tools are still processed. All added rows are
    committed together once the loop finishes.

    Args:
        db_session: Session the new `Tool` rows are added to and committed on.
        tools: Tool seeds to process; an empty list is a no-op.
    """
    if not tools:
        return

    logger.notice("Seeding Custom Tools")
    # Count real successes so the summary log does not claim that tools which
    # failed (and were skipped) were seeded.
    seeded_count = 0
    for tool in tools:
        try:
            logger.debug(f"Attempting to seed tool: {tool.name}")
            logger.debug(f"Reading definition from: {tool.definition_path}")
            with open(tool.definition_path, "r") as file:
                file_content = file.read()
            if not file_content.strip():
                raise ValueError("File is empty")
            openapi_schema = json.loads(file_content)

            db_tool = Tool(
                name=tool.name,
                description=tool.description,
                openapi_schema=openapi_schema,
                custom_headers=tool.custom_headers,
                display_name=tool.display_name,
                in_code_tool_id=tool.in_code_tool_id,
                user_id=tool.user_id,
            )
            db_session.add(db_tool)
            seeded_count += 1
            logger.debug(f"Successfully added tool: {tool.name}")
        except FileNotFoundError:
            logger.error(
                f"Definition file not found for tool {tool.name}: {tool.definition_path}"
            )
        except json.JSONDecodeError as e:
            logger.error(
                f"Invalid JSON in definition file for tool {tool.name}: {str(e)}"
            )
        except Exception as e:
            # Deliberate catch-all: one bad tool must not abort the whole seed.
            logger.error(f"Failed to seed tool {tool.name}: {str(e)}")

    db_session.commit()
    logger.notice(f"Successfully seeded {seeded_count} of {len(tools)} Custom Tools")
|
|
|
|
|
|
def _seed_llms(
    db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
) -> None:
    """Upsert every requested LLM provider and make the first one the default.

    Does nothing when `llm_upsert_requests` is empty.
    """
    if not llm_upsert_requests:
        return

    logger.notice("Seeding LLMs")
    providers = []
    for request in llm_upsert_requests:
        providers.append(upsert_llm_provider(request, db_session))

    # The default is set only after all providers have been upserted.
    first_provider = providers[0]
    update_default_provider(provider_id=first_provider.id, db_session=db_session)
|
|
|
|
|
|
def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None:
    """Upsert each persona request, in order.

    Does nothing when `personas` is empty.

    Raises:
        ValueError: if a persona has no `prompt_ids`. Personas earlier in the
            list have already been passed to `upsert_persona` at that point.
    """
    if not personas:
        return

    logger.notice("Seeding Personas")
    for request in personas:
        if not request.prompt_ids:
            raise ValueError(
                f"Invalid Persona with name {request.name}; no prompts exist"
            )

        # An unset chunk count is seeded as 0.0.
        chunk_count = 0.0 if request.num_chunks is None else request.num_chunks

        upsert_persona(
            user=None,  # Seeding is done as admin
            name=request.name,
            description=request.description,
            num_chunks=chunk_count,
            llm_relevance_filter=request.llm_relevance_filter,
            llm_filter_extraction=request.llm_filter_extraction,
            recency_bias=RecencyBiasSetting.AUTO,
            prompt_ids=request.prompt_ids,
            document_set_ids=request.document_set_ids,
            llm_model_provider_override=request.llm_model_provider_override,
            llm_model_version_override=request.llm_model_version_override,
            starter_messages=request.starter_messages,
            is_public=request.is_public,
            db_session=db_session,
            tool_ids=request.tool_ids,
            display_priority=request.display_priority,
        )
|
|
|
|
|
|
def _seed_settings(settings: Settings) -> None:
    """Persist the base settings, logging (not raising) on a `ValueError`."""
    logger.notice("Seeding Settings")
    try:
        store_base_settings(settings)
    except ValueError as err:
        logger.error(f"Failed to seed Settings: {str(err)}")
    else:
        logger.notice("Successfully seeded Settings")
|
|
|
|
|
|
def _seed_enterprise_settings(seed_config: SeedConfiguration) -> None:
    """Store enterprise settings, applying navigation item overrides if given.

    Nothing is stored when both `enterprise_settings` and `nav_item_overrides`
    are None. When overrides are present they fully replace the settings'
    `custom_nav_items`, with each logo read from the item's `svg_path`.
    """
    base_settings = seed_config.enterprise_settings
    overrides = seed_config.nav_item_overrides
    if base_settings is None and overrides is None:
        return

    # Work on a copy so the seed configuration itself is never mutated.
    if base_settings:
        settings_to_store = deepcopy(base_settings)
    else:
        settings_to_store = EnterpriseSettings()

    def _load_svg(path: str) -> str:
        # Read one SVG file and return its text without surrounding whitespace.
        with open(path, "r") as svg_file:
            return svg_file.read().strip()

    nav_items = settings_to_store.custom_nav_items
    if overrides is not None:
        nav_items = [
            NavigationItem(
                link=override.link,
                title=override.title,
                svg_logo=_load_svg(override.svg_path),
            )
            for override in overrides
        ]

    settings_to_store.custom_nav_items = nav_items

    logger.notice("Seeding enterprise settings")
    store_ee_settings(settings_to_store)
|
|
|
|
|
|
def _seed_logo(db_session: Session, logo_path: str | None) -> None:
    """Upload the logo found at `logo_path`; a missing or empty path is a no-op."""
    if not logo_path:
        return

    logger.notice("Uploading logo")
    upload_logo(db_session=db_session, file=logo_path)
|
|
|
|
|
|
def _seed_analytics_script(seed_config: SeedConfiguration) -> None:
    """Store the custom analytics script named in the seed configuration.

    Runs only when `analytics_script_path` is set and the
    `CUSTOM_ANALYTICS_SECRET_KEY` environment variable is non-empty. A missing
    script file or a `ValueError` while building or storing the upload is
    logged rather than raised.
    """
    secret_key = os.environ.get("CUSTOM_ANALYTICS_SECRET_KEY")
    script_path = seed_config.analytics_script_path
    if not script_path or not secret_key:
        return

    logger.notice("Seeding analytics script")
    try:
        with open(script_path, "r") as script_file:
            script_text = script_file.read()
        upload = AnalyticsScriptUpload(script=script_text, secret_key=secret_key)
        store_analytics_script(upload)
    except FileNotFoundError:
        logger.error(
            f"Analytics script file not found: {seed_config.analytics_script_path}"
        )
    except ValueError as err:
        logger.error(f"Failed to seed analytics script: {str(err)}")
|
|
|
|
|
|
def get_seed_config() -> SeedConfiguration | None:
    """Return the seed configuration parsed from the environment, or None."""
    config = _parse_env()
    return config
|
|
|
|
|
|
def seed_db() -> None:
    """Apply the environment-provided seed configuration to the database.

    Returns immediately when no configuration is present. Otherwise, within a
    single session, seeds in order: LLMs, personas, settings, custom tools,
    logo, enterprise settings and analytics script, then ensures the default
    standard answer category exists.
    """
    config = _parse_env()
    if config is None:
        logger.debug("No seeding configuration file passed")
        return

    with get_session_context_manager() as session:
        # Optional sections: each is skipped only when left as None.
        if config.llms is not None:
            _seed_llms(session, config.llms)
        if config.personas is not None:
            _seed_personas(session, config.personas)
        if config.settings is not None:
            _seed_settings(config.settings)
        if config.custom_tools is not None:
            _seed_custom_tools(session, config.custom_tools)

        # These helpers decide for themselves whether there is anything to do.
        _seed_logo(session, config.seeded_logo_path)
        _seed_enterprise_settings(config)
        _seed_analytics_script(config)

        logger.notice("Verifying default standard answer category exists.")
        create_initial_default_standard_answer_category(session)
|