# Snapshot metadata (from the source viewer this file was copied from):
# 2024-12-13 09:56:10 -08:00
# 254 lines
# 9.4 KiB
# Python

import json
import os
from copy import deepcopy
from typing import List
from typing import Optional
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.db.standard_answer import (
create_initial_default_standard_answer_category,
)
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
from ee.onyx.server.enterprise_settings.models import NavigationItem
from ee.onyx.server.enterprise_settings.store import store_analytics_script
from ee.onyx.server.enterprise_settings.store import (
store_settings as store_ee_settings,
)
from ee.onyx.server.enterprise_settings.store import upload_logo
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine import get_session_context_manager
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import Tool
from onyx.db.persona import upsert_persona
from onyx.server.features.persona.models import CreatePersonaRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.settings.models import Settings
from onyx.server.settings.store import store_settings as store_base_settings
from onyx.utils.logger import setup_logger
# Seed description of one custom tool. The file at ``definition_path`` is read,
# parsed as JSON and stored as the tool's ``openapi_schema`` when seeding.
class CustomToolSeed(BaseModel):
    name: str
    description: str
    # Filesystem path to the tool's JSON (OpenAPI) definition.
    definition_path: str
    custom_headers: list[dict] | None = None
    display_name: str | None = None
    in_code_tool_id: str | None = None
    user_id: str | None = None
# Environment variable whose value is the JSON-encoded SeedConfiguration.
_SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION"
# Module-level logger shared by the seeding helpers in this file.
logger = setup_logger()
# Seed form of a custom navigation item. The logo is given as a path to an SVG
# file; the file's (stripped) contents become the item's ``svg_logo`` when the
# enterprise settings are seeded.
class NavigationItemSeed(BaseModel):
    link: str
    title: str
    # NOTE: the SVG at this path must not specify a width or height.
    svg_path: str
# Top-level seeding configuration, parsed from the JSON value of the
# ENV_SEED_CONFIGURATION environment variable. Every section is optional and
# defaults to None.
class SeedConfiguration(BaseModel):
    llms: list[LLMProviderUpsertRequest] | None = None
    # NOTE(review): not read by any seeding function in this module — confirm
    # where this is consumed.
    admin_user_emails: list[str] | None = None
    seeded_logo_path: str | None = None
    personas: list[CreatePersonaRequest] | None = None
    settings: Settings | None = None
    enterprise_settings: EnterpriseSettings | None = None
    # Custom navigation items whose logos come from SVG files on disk. When set,
    # these replace any ``custom_nav_items`` from ``enterprise_settings``.
    nav_item_overrides: list[NavigationItemSeed] | None = None
    # Path to an analytics script file. The script is only stored when the
    # CUSTOM_ANALYTICS_SECRET_KEY environment variable is also set.
    analytics_script_path: str | None = None
    custom_tools: list[CustomToolSeed] | None = None
def _parse_env() -> SeedConfiguration | None:
    """Load the seed configuration from the environment.

    Returns:
        The parsed ``SeedConfiguration``, or ``None`` when the
        ``ENV_SEED_CONFIGURATION`` variable is unset or empty.

    Raises:
        pydantic.ValidationError: if the variable's value is not valid JSON or
            does not match the ``SeedConfiguration`` schema.
    """
    raw_config = os.getenv(_SEED_CONFIG_ENV_VAR_NAME)
    return SeedConfiguration.model_validate_json(raw_config) if raw_config else None
def _seed_custom_tools(db_session: Session, tools: List[CustomToolSeed]) -> None:
    """Add one ``Tool`` row per seeded custom tool, then commit.

    Each tool's ``definition_path`` is read and parsed as a JSON OpenAPI
    schema. A tool whose file is missing, empty, undecodable or not valid JSON
    is logged and skipped, so one bad definition does not block the others.
    All successfully built tools are committed together at the end.

    Args:
        db_session: Session used to add and commit the new ``Tool`` rows.
        tools: Tool seeds from the seed configuration; an empty list is a no-op.
    """
    if not tools:
        return

    logger.notice("Seeding Custom Tools")
    seeded_count = 0
    for tool in tools:
        try:
            logger.debug(f"Attempting to seed tool: {tool.name}")
            logger.debug(f"Reading definition from: {tool.definition_path}")
            # JSON is UTF-8 by specification; don't depend on the platform locale.
            with open(tool.definition_path, "r", encoding="utf-8") as file:
                file_content = file.read()
            if not file_content.strip():
                raise ValueError("File is empty")
            openapi_schema = json.loads(file_content)
            db_tool = Tool(
                name=tool.name,
                description=tool.description,
                openapi_schema=openapi_schema,
                custom_headers=tool.custom_headers,
                display_name=tool.display_name,
                in_code_tool_id=tool.in_code_tool_id,
                user_id=tool.user_id,
            )
            db_session.add(db_tool)
            seeded_count += 1
            logger.debug(f"Successfully added tool: {tool.name}")
        except FileNotFoundError:
            logger.error(
                f"Definition file not found for tool {tool.name}: {tool.definition_path}"
            )
        except json.JSONDecodeError as e:
            logger.error(
                f"Invalid JSON in definition file for tool {tool.name}: {str(e)}"
            )
        except Exception as e:
            # Best-effort seeding: covers the empty-file ValueError above,
            # decoding errors, and failures building or adding the Tool.
            logger.error(f"Failed to seed tool {tool.name}: {str(e)}")
    db_session.commit()
    # Report only the tools that were actually added, not every tool requested.
    logger.notice(f"Successfully seeded {seeded_count} Custom Tools")
def _seed_llms(
    db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
) -> None:
    """Upsert every configured LLM provider and make the first one the default.

    Args:
        db_session: Session passed to the LLM provider helpers.
        llm_upsert_requests: Providers to upsert, in order. The first entry
            becomes the default provider. An empty list is a no-op.
    """
    if not llm_upsert_requests:
        return

    logger.notice("Seeding LLMs")
    seeded_providers = []
    for request in llm_upsert_requests:
        seeded_providers.append(upsert_llm_provider(request, db_session))

    default_provider = seeded_providers[0]
    update_default_provider(provider_id=default_provider.id, db_session=db_session)
def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None:
    """Upsert every seeded persona with no owning user (seeding acts as admin).

    Personas are processed in order. A persona without ``prompt_ids`` raises
    immediately; earlier personas in the list have already been passed to
    ``upsert_persona`` by that point.

    Raises:
        ValueError: if a persona has no ``prompt_ids``.
    """
    if not personas:
        return

    logger.notice("Seeding Personas")
    for persona_request in personas:
        if not persona_request.prompt_ids:
            raise ValueError(
                f"Invalid Persona with name {persona_request.name}; no prompts exist"
            )

        # An unset chunk count is seeded as 0.0.
        num_chunks = (
            0.0 if persona_request.num_chunks is None else persona_request.num_chunks
        )
        upsert_persona(
            user=None,  # Seeding is done as admin
            name=persona_request.name,
            description=persona_request.description,
            num_chunks=num_chunks,
            llm_relevance_filter=persona_request.llm_relevance_filter,
            llm_filter_extraction=persona_request.llm_filter_extraction,
            # NOTE(review): recency bias is always seeded as AUTO, regardless of
            # the request — confirm this is intended.
            recency_bias=RecencyBiasSetting.AUTO,
            prompt_ids=persona_request.prompt_ids,
            document_set_ids=persona_request.document_set_ids,
            llm_model_provider_override=persona_request.llm_model_provider_override,
            llm_model_version_override=persona_request.llm_model_version_override,
            starter_messages=persona_request.starter_messages,
            is_public=persona_request.is_public,
            db_session=db_session,
            tool_ids=persona_request.tool_ids,
            display_priority=persona_request.display_priority,
        )
def _seed_settings(settings: Settings) -> None:
    """Persist the seeded base settings.

    A ``ValueError`` from the settings store is logged rather than raised.
    """
    logger.notice("Seeding Settings")
    try:
        store_base_settings(settings)
    except ValueError as e:
        logger.error(f"Failed to seed Settings: {str(e)}")
    else:
        logger.notice("Successfully seeded Settings")
def _seed_enterprise_settings(seed_config: SeedConfiguration) -> None:
    """Store enterprise settings, applying any seeded navigation-item overrides.

    Nothing is stored unless ``enterprise_settings`` or ``nav_item_overrides``
    is set. When ``nav_item_overrides`` is set, it replaces the copied
    settings' ``custom_nav_items``. Each override's SVG file is read, stripped
    and inlined as the item's ``svg_logo``.
    """
    base_settings = seed_config.enterprise_settings
    nav_overrides = seed_config.nav_item_overrides
    if base_settings is None and nav_overrides is None:
        return

    # Work on a copy so the parsed seed configuration itself is not mutated.
    final_enterprise_settings = (
        deepcopy(base_settings) if base_settings else EnterpriseSettings()
    )

    nav_items = final_enterprise_settings.custom_nav_items
    if nav_overrides is not None:
        nav_items = []
        for nav_seed in nav_overrides:
            with open(nav_seed.svg_path, "r") as svg_file:
                svg_logo = svg_file.read().strip()
            nav_items.append(
                NavigationItem(
                    link=nav_seed.link,
                    title=nav_seed.title,
                    svg_logo=svg_logo,
                )
            )
    # Assigned unconditionally, as before, so the model's fields-set state is unchanged.
    final_enterprise_settings.custom_nav_items = nav_items

    logger.notice("Seeding enterprise settings")
    store_ee_settings(final_enterprise_settings)
def _seed_logo(db_session: Session, logo_path: str | None) -> None:
    """Upload the seeded logo file. A ``None`` or empty path is a no-op."""
    if not logo_path:
        return
    logger.notice("Uploading logo")
    upload_logo(db_session=db_session, file=logo_path)
def _seed_analytics_script(seed_config: SeedConfiguration) -> None:
    """Store the seeded analytics script, if one is configured.

    The script is stored only when both ``analytics_script_path`` and the
    ``CUSTOM_ANALYTICS_SECRET_KEY`` environment variable are set. If the path
    is set but the key is missing, a warning is logged and nothing is stored.
    A missing script file, or a ``ValueError`` while building or storing the
    upload, is logged rather than raised.
    """
    script_path = seed_config.analytics_script_path
    if not script_path:
        return

    custom_analytics_secret_key = os.environ.get("CUSTOM_ANALYTICS_SECRET_KEY")
    if not custom_analytics_secret_key:
        # Surface the misconfiguration instead of silently dropping the script.
        logger.warning(
            "Analytics script path is configured but CUSTOM_ANALYTICS_SECRET_KEY "
            "is not set; skipping analytics script seeding"
        )
        return

    logger.notice("Seeding analytics script")
    try:
        with open(script_path, "r") as file:
            script_content = file.read()
        analytics_script = AnalyticsScriptUpload(
            script=script_content, secret_key=custom_analytics_secret_key
        )
        store_analytics_script(analytics_script)
    except FileNotFoundError:
        logger.error(f"Analytics script file not found: {script_path}")
    except ValueError as e:
        logger.error(f"Failed to seed analytics script: {str(e)}")
def get_seed_config() -> SeedConfiguration | None:
    """Return the seed configuration parsed from the environment.

    Returns ``None`` when ENV_SEED_CONFIGURATION is unset or empty.
    """
    seed_config = _parse_env()
    return seed_config
def seed_db() -> None:
    """Apply the environment-provided seed configuration.

    When no configuration is set, only logs at debug level and returns.
    Otherwise, within one DB session, it seeds in this order:
    1. LLM providers
    2. personas
    3. base settings
    4. custom tools
    5. the logo
    6. enterprise settings
    7. the analytics script

    Finally it calls ``create_initial_default_standard_answer_category``.
    """
    seed_config = _parse_env()
    if seed_config is None:
        logger.debug("No seeding configuration file passed")
        return

    with get_session_context_manager() as db_session:
        llms = seed_config.llms
        if llms is not None:
            _seed_llms(db_session, llms)

        personas = seed_config.personas
        if personas is not None:
            _seed_personas(db_session, personas)

        base_settings = seed_config.settings
        if base_settings is not None:
            _seed_settings(base_settings)

        custom_tools = seed_config.custom_tools
        if custom_tools is not None:
            _seed_custom_tools(db_session, custom_tools)

        # These helpers handle their own "not configured" cases.
        _seed_logo(db_session, seed_config.seeded_logo_path)
        _seed_enterprise_settings(seed_config)
        _seed_analytics_script(seed_config)

        logger.notice("Verifying default standard answer category exists.")
        create_initial_default_standard_answer_category(db_session)