mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-26 17:51:54 +01:00
Tenant provisioning in the dataplane (#2694)
* add tenant provisioning to data plane * minor typing update * ensure tenant router included * proper auth check * update disabling logic * validated basic provisioning * use new kv store
This commit is contained in:
parent
e00f4678df
commit
0da736bed9
@ -8,6 +8,7 @@ from email.mime.text import MIMEText
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import jwt
|
||||
from email_validator import EmailNotValidError
|
||||
from email_validator import validate_email
|
||||
from fastapi import APIRouter
|
||||
@ -37,8 +38,10 @@ from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DATA_PLANE_SECRET
|
||||
from danswer.configs.app_configs import DISABLE_AUTH
|
||||
from danswer.configs.app_configs import EMAIL_FROM
|
||||
from danswer.configs.app_configs import EXPECTED_API_KEY
|
||||
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.configs.app_configs import SMTP_PASS
|
||||
@ -504,3 +507,28 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
|
||||
def get_default_admin_user_emails_() -> list[str]:
|
||||
# No default seeding available for Danswer MIT
|
||||
return []
|
||||
|
||||
|
||||
async def control_plane_dep(request: Request) -> None:
|
||||
api_key = request.headers.get("X-API-KEY")
|
||||
if api_key != EXPECTED_API_KEY:
|
||||
logger.warning("Invalid API key")
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
logger.warning("Invalid authorization header")
|
||||
raise HTTPException(status_code=401, detail="Invalid authorization header")
|
||||
|
||||
token = auth_header.split(" ")[1]
|
||||
try:
|
||||
payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=["HS256"])
|
||||
if payload.get("scope") != "tenant:create":
|
||||
logger.warning("Insufficient permissions")
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.warning("Token has expired")
|
||||
raise HTTPException(status_code=401, detail="Token has expired")
|
||||
except jwt.InvalidTokenError:
|
||||
logger.warning("Invalid token")
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
@ -6,7 +6,6 @@ from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.chat_configs import PERSONAS_YAML
|
||||
from danswer.configs.chat_configs import PROMPTS_YAML
|
||||
from danswer.db.document_set import get_or_create_document_set_by_name
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
|
||||
from danswer.db.models import DocumentSet as DocumentSetDBModel
|
||||
from danswer.db.models import Persona
|
||||
@ -18,30 +17,32 @@ from danswer.db.persona import upsert_prompt
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
|
||||
|
||||
def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
|
||||
def load_prompts_from_yaml(
|
||||
db_session: Session, prompts_yaml: str = PROMPTS_YAML
|
||||
) -> None:
|
||||
with open(prompts_yaml, "r") as file:
|
||||
data = yaml.safe_load(file)
|
||||
|
||||
all_prompts = data.get("prompts", [])
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for prompt in all_prompts:
|
||||
upsert_prompt(
|
||||
user=None,
|
||||
prompt_id=prompt.get("id"),
|
||||
name=prompt["name"],
|
||||
description=prompt["description"].strip(),
|
||||
system_prompt=prompt["system"].strip(),
|
||||
task_prompt=prompt["task"].strip(),
|
||||
include_citations=prompt["include_citations"],
|
||||
datetime_aware=prompt.get("datetime_aware", True),
|
||||
default_prompt=True,
|
||||
personas=None,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
for prompt in all_prompts:
|
||||
upsert_prompt(
|
||||
user=None,
|
||||
prompt_id=prompt.get("id"),
|
||||
name=prompt["name"],
|
||||
description=prompt["description"].strip(),
|
||||
system_prompt=prompt["system"].strip(),
|
||||
task_prompt=prompt["task"].strip(),
|
||||
include_citations=prompt["include_citations"],
|
||||
datetime_aware=prompt.get("datetime_aware", True),
|
||||
default_prompt=True,
|
||||
personas=None,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
|
||||
def load_personas_from_yaml(
|
||||
db_session: Session,
|
||||
personas_yaml: str = PERSONAS_YAML,
|
||||
default_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
||||
) -> None:
|
||||
@ -49,117 +50,117 @@ def load_personas_from_yaml(
|
||||
data = yaml.safe_load(file)
|
||||
|
||||
all_personas = data.get("personas", [])
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for persona in all_personas:
|
||||
doc_set_names = persona["document_sets"]
|
||||
doc_sets: list[DocumentSetDBModel] = [
|
||||
get_or_create_document_set_by_name(db_session, name)
|
||||
for name in doc_set_names
|
||||
for persona in all_personas:
|
||||
doc_set_names = persona["document_sets"]
|
||||
doc_sets: list[DocumentSetDBModel] = [
|
||||
get_or_create_document_set_by_name(db_session, name)
|
||||
for name in doc_set_names
|
||||
]
|
||||
|
||||
# Assume if user hasn't set any document sets for the persona, the user may want
|
||||
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
|
||||
# the document sets for the persona
|
||||
doc_set_ids: list[int] | None = None
|
||||
if doc_sets:
|
||||
doc_set_ids = [doc_set.id for doc_set in doc_sets]
|
||||
else:
|
||||
doc_set_ids = None
|
||||
|
||||
prompt_ids: list[int] | None = None
|
||||
prompt_set_names = persona["prompts"]
|
||||
if prompt_set_names:
|
||||
prompts: list[PromptDBModel | None] = [
|
||||
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
|
||||
for prompt_name in prompt_set_names
|
||||
]
|
||||
if any([prompt is None for prompt in prompts]):
|
||||
raise ValueError("Invalid Persona configs, not all prompts exist")
|
||||
|
||||
# Assume if user hasn't set any document sets for the persona, the user may want
|
||||
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
|
||||
# the document sets for the persona
|
||||
doc_set_ids: list[int] | None = None
|
||||
if doc_sets:
|
||||
doc_set_ids = [doc_set.id for doc_set in doc_sets]
|
||||
else:
|
||||
doc_set_ids = None
|
||||
if prompts:
|
||||
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
|
||||
|
||||
prompt_ids: list[int] | None = None
|
||||
prompt_set_names = persona["prompts"]
|
||||
if prompt_set_names:
|
||||
prompts: list[PromptDBModel | None] = [
|
||||
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
|
||||
for prompt_name in prompt_set_names
|
||||
]
|
||||
if any([prompt is None for prompt in prompts]):
|
||||
raise ValueError("Invalid Persona configs, not all prompts exist")
|
||||
|
||||
if prompts:
|
||||
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
|
||||
|
||||
p_id = persona.get("id")
|
||||
tool_ids = []
|
||||
if persona.get("image_generation"):
|
||||
image_gen_tool = (
|
||||
db_session.query(ToolDBModel)
|
||||
.filter(ToolDBModel.name == "ImageGenerationTool")
|
||||
.first()
|
||||
)
|
||||
if image_gen_tool:
|
||||
tool_ids.append(image_gen_tool.id)
|
||||
|
||||
llm_model_provider_override = persona.get("llm_model_provider_override")
|
||||
llm_model_version_override = persona.get("llm_model_version_override")
|
||||
|
||||
# Set specific overrides for image generation persona
|
||||
if persona.get("image_generation"):
|
||||
llm_model_version_override = "gpt-4o"
|
||||
|
||||
existing_persona = (
|
||||
db_session.query(Persona)
|
||||
.filter(Persona.name == persona["name"])
|
||||
p_id = persona.get("id")
|
||||
tool_ids = []
|
||||
if persona.get("image_generation"):
|
||||
image_gen_tool = (
|
||||
db_session.query(ToolDBModel)
|
||||
.filter(ToolDBModel.name == "ImageGenerationTool")
|
||||
.first()
|
||||
)
|
||||
if image_gen_tool:
|
||||
tool_ids.append(image_gen_tool.id)
|
||||
|
||||
upsert_persona(
|
||||
user=None,
|
||||
persona_id=(-1 * p_id) if p_id is not None else None,
|
||||
name=persona["name"],
|
||||
description=persona["description"],
|
||||
num_chunks=persona.get("num_chunks")
|
||||
if persona.get("num_chunks") is not None
|
||||
else default_chunks,
|
||||
llm_relevance_filter=persona.get("llm_relevance_filter"),
|
||||
starter_messages=persona.get("starter_messages"),
|
||||
llm_filter_extraction=persona.get("llm_filter_extraction"),
|
||||
icon_shape=persona.get("icon_shape"),
|
||||
icon_color=persona.get("icon_color"),
|
||||
llm_model_provider_override=llm_model_provider_override,
|
||||
llm_model_version_override=llm_model_version_override,
|
||||
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
|
||||
prompt_ids=prompt_ids,
|
||||
document_set_ids=doc_set_ids,
|
||||
tool_ids=tool_ids,
|
||||
builtin_persona=True,
|
||||
is_public=True,
|
||||
display_priority=existing_persona.display_priority
|
||||
if existing_persona is not None
|
||||
else persona.get("display_priority"),
|
||||
is_visible=existing_persona.is_visible
|
||||
if existing_persona is not None
|
||||
else persona.get("is_visible"),
|
||||
db_session=db_session,
|
||||
)
|
||||
llm_model_provider_override = persona.get("llm_model_provider_override")
|
||||
llm_model_version_override = persona.get("llm_model_version_override")
|
||||
|
||||
# Set specific overrides for image generation persona
|
||||
if persona.get("image_generation"):
|
||||
llm_model_version_override = "gpt-4o"
|
||||
|
||||
existing_persona = (
|
||||
db_session.query(Persona).filter(Persona.name == persona["name"]).first()
|
||||
)
|
||||
|
||||
upsert_persona(
|
||||
user=None,
|
||||
persona_id=(-1 * p_id) if p_id is not None else None,
|
||||
name=persona["name"],
|
||||
description=persona["description"],
|
||||
num_chunks=persona.get("num_chunks")
|
||||
if persona.get("num_chunks") is not None
|
||||
else default_chunks,
|
||||
llm_relevance_filter=persona.get("llm_relevance_filter"),
|
||||
starter_messages=persona.get("starter_messages"),
|
||||
llm_filter_extraction=persona.get("llm_filter_extraction"),
|
||||
icon_shape=persona.get("icon_shape"),
|
||||
icon_color=persona.get("icon_color"),
|
||||
llm_model_provider_override=llm_model_provider_override,
|
||||
llm_model_version_override=llm_model_version_override,
|
||||
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
|
||||
prompt_ids=prompt_ids,
|
||||
document_set_ids=doc_set_ids,
|
||||
tool_ids=tool_ids,
|
||||
builtin_persona=True,
|
||||
is_public=True,
|
||||
display_priority=existing_persona.display_priority
|
||||
if existing_persona is not None
|
||||
else persona.get("display_priority"),
|
||||
is_visible=existing_persona.is_visible
|
||||
if existing_persona is not None
|
||||
else persona.get("is_visible"),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def load_input_prompts_from_yaml(input_prompts_yaml: str = INPUT_PROMPT_YAML) -> None:
|
||||
def load_input_prompts_from_yaml(
|
||||
db_session: Session, input_prompts_yaml: str = INPUT_PROMPT_YAML
|
||||
) -> None:
|
||||
with open(input_prompts_yaml, "r") as file:
|
||||
data = yaml.safe_load(file)
|
||||
|
||||
all_input_prompts = data.get("input_prompts", [])
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for input_prompt in all_input_prompts:
|
||||
# If these prompts are deleted (which is a hard delete in the DB), on server startup
|
||||
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
|
||||
insert_input_prompt_if_not_exists(
|
||||
user=None,
|
||||
input_prompt_id=input_prompt.get("id"),
|
||||
prompt=input_prompt["prompt"],
|
||||
content=input_prompt["content"],
|
||||
is_public=input_prompt["is_public"],
|
||||
active=input_prompt.get("active", True),
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
for input_prompt in all_input_prompts:
|
||||
# If these prompts are deleted (which is a hard delete in the DB), on server startup
|
||||
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
|
||||
|
||||
insert_input_prompt_if_not_exists(
|
||||
user=None,
|
||||
input_prompt_id=input_prompt.get("id"),
|
||||
prompt=input_prompt["prompt"],
|
||||
content=input_prompt["content"],
|
||||
is_public=input_prompt["is_public"],
|
||||
active=input_prompt.get("active", True),
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
|
||||
def load_chat_yamls(
|
||||
db_session: Session,
|
||||
prompt_yaml: str = PROMPTS_YAML,
|
||||
personas_yaml: str = PERSONAS_YAML,
|
||||
input_prompts_yaml: str = INPUT_PROMPT_YAML,
|
||||
) -> None:
|
||||
load_prompts_from_yaml(prompt_yaml)
|
||||
load_personas_from_yaml(personas_yaml)
|
||||
load_input_prompts_from_yaml(input_prompts_yaml)
|
||||
load_prompts_from_yaml(db_session, prompt_yaml)
|
||||
load_personas_from_yaml(db_session, personas_yaml)
|
||||
load_input_prompts_from_yaml(db_session, input_prompts_yaml)
|
||||
|
@ -416,3 +416,7 @@ ENTERPRISE_EDITION_ENABLED = (
|
||||
|
||||
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
|
||||
SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "")
|
||||
|
||||
|
||||
DATA_PLANE_SECRET = os.environ.get("DATA_PLANE_SECRET", "")
|
||||
EXPECTED_API_KEY = os.environ.get("EXPECTED_API_KEY", "")
|
||||
|
@ -117,7 +117,7 @@ def get_db_current_time(db_session: Session) -> datetime:
|
||||
|
||||
|
||||
# Regular expression to validate schema names to prevent SQL injection
|
||||
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
|
||||
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
|
||||
|
||||
|
||||
def is_valid_schema_name(name: str) -> bool:
|
||||
@ -281,6 +281,7 @@ def get_session_with_tenant(tenant_id: str | None = None) -> Session:
|
||||
tenant_id = current_tenant_id.get()
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
logger.error(f"Invalid tenant ID: {tenant_id}")
|
||||
raise Exception("Invalid tenant ID")
|
||||
|
||||
engine = SqlEngine.get_engine()
|
||||
|
@ -5,7 +5,7 @@ from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.engine import get_session_factory
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import KVStore
|
||||
from danswer.key_value_store.interface import JSON_ro
|
||||
from danswer.key_value_store.interface import KeyValueStore
|
||||
@ -26,12 +26,9 @@ class PgRedisKVStore(KeyValueStore):
|
||||
|
||||
@contextmanager
|
||||
def get_session(self) -> Iterator[Session]:
|
||||
factory = get_session_factory()
|
||||
session: Session = factory()
|
||||
try:
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine, expire_on_commit=False) as session:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
|
||||
# Not encrypted in Redis, but encrypted in Postgres
|
||||
|
@ -1,4 +1,3 @@
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
@ -23,13 +22,11 @@ from danswer.auth.schemas import UserRead
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.auth.users import auth_backend
|
||||
from danswer.auth.users import fastapi_users
|
||||
from danswer.chat.load_yamls import load_chat_yamls
|
||||
from danswer.configs.app_configs import APP_API_PREFIX
|
||||
from danswer.configs.app_configs import APP_HOST
|
||||
from danswer.configs.app_configs import APP_PORT
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from danswer.configs.app_configs import LOG_ENDPOINT_LATENCY
|
||||
from danswer.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
@ -38,42 +35,9 @@ from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
|
||||
from danswer.configs.app_configs import USER_AUTH_SECRET
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.configs.constants import KV_REINDEX_KEY
|
||||
from danswer.configs.constants import KV_SEARCH_SETTINGS
|
||||
from danswer.configs.constants import POSTGRES_WEB_APP_NAME
|
||||
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
|
||||
from danswer.configs.model_configs import GEN_AI_API_KEY
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
|
||||
from danswer.db.connector import check_connectors_exist
|
||||
from danswer.db.connector import create_initial_default_connector
|
||||
from danswer.db.connector_credential_pair import associate_default_cc_pair
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.connector_credential_pair import resync_cc_pair
|
||||
from danswer.db.credentials import create_initial_public_credential
|
||||
from danswer.db.document import check_docs_exist
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.db.engine import warm_up_connections
|
||||
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from danswer.db.index_attempt import expire_index_attempts
|
||||
from danswer.db.llm import fetch_default_provider
|
||||
from danswer.db.llm import update_default_provider
|
||||
from danswer.db.llm import upsert_llm_provider
|
||||
from danswer.db.persona import delete_old_default_personas
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.search_settings import update_current_search_settings
|
||||
from danswer.db.search_settings import update_secondary_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.indexing.models import IndexingSetting
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
|
||||
from danswer.search.models import SavedSearchSettings
|
||||
from danswer.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.server.auth_check import check_router_auth
|
||||
from danswer.server.danswer_api.ingestion import router as danswer_api_router
|
||||
from danswer.server.documents.cc_pair import router as cc_pair_router
|
||||
@ -99,7 +63,6 @@ from danswer.server.manage.embedding.api import basic_router as embedding_router
|
||||
from danswer.server.manage.get_state import router as state_router
|
||||
from danswer.server.manage.llm.api import admin_router as llm_admin_router
|
||||
from danswer.server.manage.llm.api import basic_router as llm_router
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from danswer.server.manage.search_settings import router as search_settings_router
|
||||
from danswer.server.manage.slack_bot import router as slack_bot_management_router
|
||||
from danswer.server.manage.users import router as user_router
|
||||
@ -111,15 +74,10 @@ from danswer.server.query_and_chat.query_backend import (
|
||||
from danswer.server.query_and_chat.query_backend import basic_router as query_router
|
||||
from danswer.server.settings.api import admin_router as settings_admin_router
|
||||
from danswer.server.settings.api import basic_router as settings_router
|
||||
from danswer.server.settings.store import load_settings
|
||||
from danswer.server.settings.store import store_settings
|
||||
from danswer.server.token_rate_limits.api import (
|
||||
router as token_rate_limit_settings_router,
|
||||
)
|
||||
from danswer.tools.built_in_tools import auto_add_search_tool_to_personas
|
||||
from danswer.tools.built_in_tools import load_builtin_tools
|
||||
from danswer.tools.built_in_tools import refresh_built_in_tools_cache
|
||||
from danswer.utils.gpu_utils import gpu_status_request
|
||||
from danswer.setup import setup_danswer
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import get_or_generate_uuid
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
@ -128,8 +86,6 @@ from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import CORS_ALLOWED_ORIGIN
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@ -182,176 +138,6 @@ def include_router_with_global_prefix_prepended(
|
||||
application.include_router(router, **final_kwargs)
|
||||
|
||||
|
||||
def setup_postgres(db_session: Session) -> None:
|
||||
logger.notice("Verifying default connector/credential exist.")
|
||||
create_initial_public_credential(db_session)
|
||||
create_initial_default_connector(db_session)
|
||||
associate_default_cc_pair(db_session)
|
||||
|
||||
logger.notice("Loading default Prompts and Personas")
|
||||
delete_old_default_personas(db_session)
|
||||
load_chat_yamls()
|
||||
|
||||
logger.notice("Loading built-in tools")
|
||||
load_builtin_tools(db_session)
|
||||
refresh_built_in_tools_cache(db_session)
|
||||
auto_add_search_tool_to_personas(db_session)
|
||||
|
||||
if GEN_AI_API_KEY and fetch_default_provider(db_session) is None:
|
||||
# Only for dev flows
|
||||
logger.notice("Setting up default OpenAI LLM for dev.")
|
||||
llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
|
||||
fast_model = FAST_GEN_AI_MODEL_VERSION or "gpt-4o-mini"
|
||||
model_req = LLMProviderUpsertRequest(
|
||||
name="DevEnvPresetOpenAI",
|
||||
provider="openai",
|
||||
api_key=GEN_AI_API_KEY,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
default_model_name=llm_model,
|
||||
fast_default_model_name=fast_model,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
display_model_names=[llm_model, fast_model],
|
||||
model_names=[llm_model, fast_model],
|
||||
)
|
||||
new_llm_provider = upsert_llm_provider(
|
||||
llm_provider=model_req, db_session=db_session
|
||||
)
|
||||
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
|
||||
|
||||
|
||||
def update_default_multipass_indexing(db_session: Session) -> None:
|
||||
docs_exist = check_docs_exist(db_session)
|
||||
connectors_exist = check_connectors_exist(db_session)
|
||||
logger.debug(f"Docs exist: {docs_exist}, Connectors exist: {connectors_exist}")
|
||||
|
||||
if not docs_exist and not connectors_exist:
|
||||
logger.info(
|
||||
"No existing docs or connectors found. Checking GPU availability for multipass indexing."
|
||||
)
|
||||
gpu_available = gpu_status_request()
|
||||
logger.info(f"GPU available: {gpu_available}")
|
||||
|
||||
current_settings = get_current_search_settings(db_session)
|
||||
|
||||
logger.notice(f"Updating multipass indexing setting to: {gpu_available}")
|
||||
updated_settings = SavedSearchSettings.from_db_model(current_settings)
|
||||
# Enable multipass indexing if GPU is available or if using a cloud provider
|
||||
updated_settings.multipass_indexing = (
|
||||
gpu_available or current_settings.cloud_provider is not None
|
||||
)
|
||||
update_current_search_settings(db_session, updated_settings)
|
||||
|
||||
# Update settings with GPU availability
|
||||
settings = load_settings()
|
||||
settings.gpu_enabled = gpu_available
|
||||
store_settings(settings)
|
||||
logger.notice(f"Updated settings with GPU availability: {gpu_available}")
|
||||
|
||||
else:
|
||||
logger.debug(
|
||||
"Existing docs or connectors found. Skipping multipass indexing update."
|
||||
)
|
||||
|
||||
|
||||
def translate_saved_search_settings(db_session: Session) -> None:
|
||||
kv_store = get_kv_store()
|
||||
|
||||
try:
|
||||
search_settings_dict = kv_store.load(KV_SEARCH_SETTINGS)
|
||||
if isinstance(search_settings_dict, dict):
|
||||
# Update current search settings
|
||||
current_settings = get_current_search_settings(db_session)
|
||||
|
||||
# Update non-preserved fields
|
||||
if current_settings:
|
||||
current_settings_dict = SavedSearchSettings.from_db_model(
|
||||
current_settings
|
||||
).dict()
|
||||
|
||||
new_current_settings = SavedSearchSettings(
|
||||
**{**current_settings_dict, **search_settings_dict}
|
||||
)
|
||||
update_current_search_settings(db_session, new_current_settings)
|
||||
|
||||
# Update secondary search settings
|
||||
secondary_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_settings:
|
||||
secondary_settings_dict = SavedSearchSettings.from_db_model(
|
||||
secondary_settings
|
||||
).dict()
|
||||
|
||||
new_secondary_settings = SavedSearchSettings(
|
||||
**{**secondary_settings_dict, **search_settings_dict}
|
||||
)
|
||||
update_secondary_search_settings(
|
||||
db_session,
|
||||
new_secondary_settings,
|
||||
)
|
||||
# Delete the KV store entry after successful update
|
||||
kv_store.delete(KV_SEARCH_SETTINGS)
|
||||
logger.notice("Search settings updated and KV store entry deleted.")
|
||||
else:
|
||||
logger.notice("KV store search settings is empty.")
|
||||
except KvKeyNotFoundError:
|
||||
logger.notice("No search config found in KV store.")
|
||||
|
||||
|
||||
def mark_reindex_flag(db_session: Session) -> None:
|
||||
kv_store = get_kv_store()
|
||||
try:
|
||||
value = kv_store.load(KV_REINDEX_KEY)
|
||||
logger.debug(f"Re-indexing flag has value {value}")
|
||||
return
|
||||
except KvKeyNotFoundError:
|
||||
# Only need to update the flag if it hasn't been set
|
||||
pass
|
||||
|
||||
# If their first deployment is after the changes, it will
|
||||
# enable this when the other changes go in, need to avoid
|
||||
# this being set to False, then the user indexes things on the old version
|
||||
docs_exist = check_docs_exist(db_session)
|
||||
connectors_exist = check_connectors_exist(db_session)
|
||||
if docs_exist or connectors_exist:
|
||||
kv_store.store(KV_REINDEX_KEY, True)
|
||||
else:
|
||||
kv_store.store(KV_REINDEX_KEY, False)
|
||||
|
||||
|
||||
def setup_vespa(
|
||||
document_index: DocumentIndex,
|
||||
index_setting: IndexingSetting,
|
||||
secondary_index_setting: IndexingSetting | None,
|
||||
) -> bool:
|
||||
# Vespa startup is a bit slow, so give it a few seconds
|
||||
WAIT_SECONDS = 5
|
||||
VESPA_ATTEMPTS = 5
|
||||
for x in range(VESPA_ATTEMPTS):
|
||||
try:
|
||||
logger.notice(f"Setting up Vespa (attempt {x+1}/{VESPA_ATTEMPTS})...")
|
||||
document_index.ensure_indices_exist(
|
||||
index_embedding_dim=index_setting.model_dim,
|
||||
secondary_index_embedding_dim=secondary_index_setting.model_dim
|
||||
if secondary_index_setting
|
||||
else None,
|
||||
)
|
||||
|
||||
logger.notice("Vespa setup complete.")
|
||||
return True
|
||||
except Exception:
|
||||
logger.notice(
|
||||
f"Vespa setup did not succeed. The Vespa service may not be ready yet. Retrying in {WAIT_SECONDS} seconds."
|
||||
)
|
||||
time.sleep(WAIT_SECONDS)
|
||||
|
||||
logger.error(
|
||||
f"Vespa setup did not succeed. Attempt limit reached. ({VESPA_ATTEMPTS})"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
SqlEngine.set_app_name(POSTGRES_WEB_APP_NAME)
|
||||
@ -380,89 +166,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
get_or_generate_uuid()
|
||||
|
||||
with Session(engine) as db_session:
|
||||
check_index_swap(db_session=db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
|
||||
# Break bad state for thrashing indexes
|
||||
if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
expire_index_attempts(
|
||||
search_settings_id=search_settings.id, db_session=db_session
|
||||
)
|
||||
|
||||
for cc_pair in get_connector_credential_pairs(db_session):
|
||||
resync_cc_pair(cc_pair, db_session=db_session)
|
||||
|
||||
# Expire all old embedding models indexing attempts, technically redundant
|
||||
cancel_indexing_attempts_past_model(db_session)
|
||||
|
||||
logger.notice(f'Using Embedding model: "{search_settings.model_name}"')
|
||||
if search_settings.query_prefix or search_settings.passage_prefix:
|
||||
logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"')
|
||||
logger.notice(
|
||||
f'Passage embedding prefix: "{search_settings.passage_prefix}"'
|
||||
)
|
||||
|
||||
if search_settings:
|
||||
if not search_settings.disable_rerank_for_streaming:
|
||||
logger.notice("Reranking is enabled.")
|
||||
|
||||
if search_settings.multilingual_expansion:
|
||||
logger.notice(
|
||||
f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
|
||||
)
|
||||
if (
|
||||
search_settings.rerank_model_name
|
||||
and not search_settings.provider_type
|
||||
and not search_settings.rerank_provider_type
|
||||
):
|
||||
warm_up_cross_encoder(search_settings.rerank_model_name)
|
||||
|
||||
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
|
||||
download_nltk_data()
|
||||
|
||||
# setup Postgres with default credential, llm providers, etc.
|
||||
setup_postgres(db_session)
|
||||
|
||||
translate_saved_search_settings(db_session)
|
||||
|
||||
# Does the user need to trigger a reindexing to bring the document index
|
||||
# into a good state, marked in the kv store
|
||||
mark_reindex_flag(db_session)
|
||||
|
||||
# ensure Vespa is setup correctly
|
||||
logger.notice("Verifying Document Index(s) is/are available.")
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=search_settings.index_name,
|
||||
secondary_index_name=secondary_search_settings.index_name
|
||||
if secondary_search_settings
|
||||
else None,
|
||||
)
|
||||
|
||||
success = setup_vespa(
|
||||
document_index,
|
||||
IndexingSetting.from_db_model(search_settings),
|
||||
IndexingSetting.from_db_model(secondary_search_settings)
|
||||
if secondary_search_settings
|
||||
else None,
|
||||
)
|
||||
if not success:
|
||||
raise RuntimeError(
|
||||
"Could not connect to Vespa within the specified timeout."
|
||||
)
|
||||
|
||||
logger.notice(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
|
||||
if search_settings.provider_type is None:
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
),
|
||||
)
|
||||
|
||||
# update multipass indexing setting based on GPU availability
|
||||
update_default_multipass_indexing(db_session)
|
||||
setup_danswer(db_session)
|
||||
|
||||
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
|
||||
yield
|
||||
|
@ -4,6 +4,7 @@ from fastapi import FastAPI
|
||||
from fastapi.dependencies.models import Dependant
|
||||
from starlette.routing import BaseRoute
|
||||
|
||||
from danswer.auth.users import control_plane_dep
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
@ -98,6 +99,7 @@ def check_router_auth(
|
||||
or depends_fn == current_curator_or_admin_user
|
||||
or depends_fn == api_key_dep
|
||||
or depends_fn == current_user_with_expired_token
|
||||
or depends_fn == control_plane_dep
|
||||
):
|
||||
found_auth = True
|
||||
break
|
||||
|
303
backend/danswer/setup.py
Normal file
303
backend/danswer/setup.py
Normal file
@ -0,0 +1,303 @@
|
||||
import time
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.load_yamls import load_chat_yamls
|
||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from danswer.configs.constants import KV_REINDEX_KEY
|
||||
from danswer.configs.constants import KV_SEARCH_SETTINGS
|
||||
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
|
||||
from danswer.configs.model_configs import GEN_AI_API_KEY
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
|
||||
from danswer.db.connector import check_connectors_exist
|
||||
from danswer.db.connector import create_initial_default_connector
|
||||
from danswer.db.connector_credential_pair import associate_default_cc_pair
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.connector_credential_pair import resync_cc_pair
|
||||
from danswer.db.credentials import create_initial_public_credential
|
||||
from danswer.db.document import check_docs_exist
|
||||
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from danswer.db.index_attempt import expire_index_attempts
|
||||
from danswer.db.llm import fetch_default_provider
|
||||
from danswer.db.llm import update_default_provider
|
||||
from danswer.db.llm import upsert_llm_provider
|
||||
from danswer.db.persona import delete_old_default_personas
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.search_settings import update_current_search_settings
|
||||
from danswer.db.search_settings import update_secondary_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.indexing.models import IndexingSetting
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
|
||||
from danswer.search.models import SavedSearchSettings
|
||||
from danswer.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from danswer.server.settings.store import load_settings
|
||||
from danswer.server.settings.store import store_settings
|
||||
from danswer.tools.built_in_tools import auto_add_search_tool_to_personas
|
||||
from danswer.tools.built_in_tools import load_builtin_tools
|
||||
from danswer.tools.built_in_tools import refresh_built_in_tools_cache
|
||||
from danswer.utils.gpu_utils import gpu_status_request
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def setup_danswer(db_session: Session) -> None:
|
||||
check_index_swap(db_session=db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
|
||||
# Break bad state for thrashing indexes
|
||||
if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
expire_index_attempts(
|
||||
search_settings_id=search_settings.id, db_session=db_session
|
||||
)
|
||||
|
||||
for cc_pair in get_connector_credential_pairs(db_session):
|
||||
resync_cc_pair(cc_pair, db_session=db_session)
|
||||
|
||||
# Expire all old embedding models indexing attempts, technically redundant
|
||||
cancel_indexing_attempts_past_model(db_session)
|
||||
|
||||
logger.notice(f'Using Embedding model: "{search_settings.model_name}"')
|
||||
if search_settings.query_prefix or search_settings.passage_prefix:
|
||||
logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"')
|
||||
logger.notice(f'Passage embedding prefix: "{search_settings.passage_prefix}"')
|
||||
|
||||
if search_settings:
|
||||
if not search_settings.disable_rerank_for_streaming:
|
||||
logger.notice("Reranking is enabled.")
|
||||
|
||||
if search_settings.multilingual_expansion:
|
||||
logger.notice(
|
||||
f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
|
||||
)
|
||||
if (
|
||||
search_settings.rerank_model_name
|
||||
and not search_settings.provider_type
|
||||
and not search_settings.rerank_provider_type
|
||||
):
|
||||
warm_up_cross_encoder(search_settings.rerank_model_name)
|
||||
|
||||
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
|
||||
download_nltk_data()
|
||||
|
||||
# setup Postgres with default credential, llm providers, etc.
|
||||
setup_postgres(db_session)
|
||||
|
||||
translate_saved_search_settings(db_session)
|
||||
|
||||
# Does the user need to trigger a reindexing to bring the document index
|
||||
# into a good state, marked in the kv store
|
||||
mark_reindex_flag(db_session)
|
||||
|
||||
# ensure Vespa is setup correctly
|
||||
logger.notice("Verifying Document Index(s) is/are available.")
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=search_settings.index_name,
|
||||
secondary_index_name=secondary_search_settings.index_name
|
||||
if secondary_search_settings
|
||||
else None,
|
||||
)
|
||||
|
||||
success = setup_vespa(
|
||||
document_index,
|
||||
IndexingSetting.from_db_model(search_settings),
|
||||
IndexingSetting.from_db_model(secondary_search_settings)
|
||||
if secondary_search_settings
|
||||
else None,
|
||||
)
|
||||
if not success:
|
||||
raise RuntimeError("Could not connect to Vespa within the specified timeout.")
|
||||
|
||||
logger.notice(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
|
||||
if search_settings.provider_type is None:
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
),
|
||||
)
|
||||
|
||||
# update multipass indexing setting based on GPU availability
|
||||
update_default_multipass_indexing(db_session)
|
||||
|
||||
|
||||
def translate_saved_search_settings(db_session: Session) -> None:
|
||||
kv_store = get_kv_store()
|
||||
|
||||
try:
|
||||
search_settings_dict = kv_store.load(KV_SEARCH_SETTINGS)
|
||||
if isinstance(search_settings_dict, dict):
|
||||
# Update current search settings
|
||||
current_settings = get_current_search_settings(db_session)
|
||||
|
||||
# Update non-preserved fields
|
||||
if current_settings:
|
||||
current_settings_dict = SavedSearchSettings.from_db_model(
|
||||
current_settings
|
||||
).dict()
|
||||
|
||||
new_current_settings = SavedSearchSettings(
|
||||
**{**current_settings_dict, **search_settings_dict}
|
||||
)
|
||||
update_current_search_settings(db_session, new_current_settings)
|
||||
|
||||
# Update secondary search settings
|
||||
secondary_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_settings:
|
||||
secondary_settings_dict = SavedSearchSettings.from_db_model(
|
||||
secondary_settings
|
||||
).dict()
|
||||
|
||||
new_secondary_settings = SavedSearchSettings(
|
||||
**{**secondary_settings_dict, **search_settings_dict}
|
||||
)
|
||||
update_secondary_search_settings(
|
||||
db_session,
|
||||
new_secondary_settings,
|
||||
)
|
||||
# Delete the KV store entry after successful update
|
||||
kv_store.delete(KV_SEARCH_SETTINGS)
|
||||
logger.notice("Search settings updated and KV store entry deleted.")
|
||||
else:
|
||||
logger.notice("KV store search settings is empty.")
|
||||
except KvKeyNotFoundError:
|
||||
logger.notice("No search config found in KV store.")
|
||||
|
||||
|
||||
def mark_reindex_flag(db_session: Session) -> None:
|
||||
kv_store = get_kv_store()
|
||||
try:
|
||||
value = kv_store.load(KV_REINDEX_KEY)
|
||||
logger.debug(f"Re-indexing flag has value {value}")
|
||||
return
|
||||
except KvKeyNotFoundError:
|
||||
# Only need to update the flag if it hasn't been set
|
||||
pass
|
||||
|
||||
# If their first deployment is after the changes, it will
|
||||
# enable this when the other changes go in, need to avoid
|
||||
# this being set to False, then the user indexes things on the old version
|
||||
docs_exist = check_docs_exist(db_session)
|
||||
connectors_exist = check_connectors_exist(db_session)
|
||||
if docs_exist or connectors_exist:
|
||||
kv_store.store(KV_REINDEX_KEY, True)
|
||||
else:
|
||||
kv_store.store(KV_REINDEX_KEY, False)
|
||||
|
||||
|
||||
def setup_vespa(
|
||||
document_index: DocumentIndex,
|
||||
index_setting: IndexingSetting,
|
||||
secondary_index_setting: IndexingSetting | None,
|
||||
) -> bool:
|
||||
# Vespa startup is a bit slow, so give it a few seconds
|
||||
WAIT_SECONDS = 5
|
||||
VESPA_ATTEMPTS = 5
|
||||
for x in range(VESPA_ATTEMPTS):
|
||||
try:
|
||||
logger.notice(f"Setting up Vespa (attempt {x+1}/{VESPA_ATTEMPTS})...")
|
||||
document_index.ensure_indices_exist(
|
||||
index_embedding_dim=index_setting.model_dim,
|
||||
secondary_index_embedding_dim=secondary_index_setting.model_dim
|
||||
if secondary_index_setting
|
||||
else None,
|
||||
)
|
||||
|
||||
logger.notice("Vespa setup complete.")
|
||||
return True
|
||||
except Exception:
|
||||
logger.notice(
|
||||
f"Vespa setup did not succeed. The Vespa service may not be ready yet. Retrying in {WAIT_SECONDS} seconds."
|
||||
)
|
||||
time.sleep(WAIT_SECONDS)
|
||||
|
||||
logger.error(
|
||||
f"Vespa setup did not succeed. Attempt limit reached. ({VESPA_ATTEMPTS})"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def setup_postgres(db_session: Session) -> None:
|
||||
logger.notice("Verifying default connector/credential exist.")
|
||||
create_initial_public_credential(db_session)
|
||||
create_initial_default_connector(db_session)
|
||||
associate_default_cc_pair(db_session)
|
||||
|
||||
logger.notice("Loading default Prompts and Personas")
|
||||
delete_old_default_personas(db_session)
|
||||
load_chat_yamls(db_session)
|
||||
|
||||
logger.notice("Loading built-in tools")
|
||||
load_builtin_tools(db_session)
|
||||
refresh_built_in_tools_cache(db_session)
|
||||
auto_add_search_tool_to_personas(db_session)
|
||||
|
||||
if GEN_AI_API_KEY and fetch_default_provider(db_session) is None:
|
||||
# Only for dev flows
|
||||
logger.notice("Setting up default OpenAI LLM for dev.")
|
||||
llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
|
||||
fast_model = FAST_GEN_AI_MODEL_VERSION or "gpt-4o-mini"
|
||||
model_req = LLMProviderUpsertRequest(
|
||||
name="DevEnvPresetOpenAI",
|
||||
provider="openai",
|
||||
api_key=GEN_AI_API_KEY,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
default_model_name=llm_model,
|
||||
fast_default_model_name=fast_model,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
display_model_names=[llm_model, fast_model],
|
||||
model_names=[llm_model, fast_model],
|
||||
)
|
||||
new_llm_provider = upsert_llm_provider(
|
||||
llm_provider=model_req, db_session=db_session
|
||||
)
|
||||
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
|
||||
|
||||
|
||||
def update_default_multipass_indexing(db_session: Session) -> None:
|
||||
docs_exist = check_docs_exist(db_session)
|
||||
connectors_exist = check_connectors_exist(db_session)
|
||||
logger.debug(f"Docs exist: {docs_exist}, Connectors exist: {connectors_exist}")
|
||||
|
||||
if not docs_exist and not connectors_exist:
|
||||
logger.info(
|
||||
"No existing docs or connectors found. Checking GPU availability for multipass indexing."
|
||||
)
|
||||
gpu_available = gpu_status_request()
|
||||
logger.info(f"GPU available: {gpu_available}")
|
||||
|
||||
current_settings = get_current_search_settings(db_session)
|
||||
|
||||
logger.notice(f"Updating multipass indexing setting to: {gpu_available}")
|
||||
updated_settings = SavedSearchSettings.from_db_model(current_settings)
|
||||
# Enable multipass indexing if GPU is available or if using a cloud provider
|
||||
updated_settings.multipass_indexing = (
|
||||
gpu_available or current_settings.cloud_provider is not None
|
||||
)
|
||||
update_current_search_settings(db_session, updated_settings)
|
||||
|
||||
# Update settings with GPU availability
|
||||
settings = load_settings()
|
||||
settings.gpu_enabled = gpu_available
|
||||
store_settings(settings)
|
||||
logger.notice(f"Updated settings with GPU availability: {gpu_available}")
|
||||
|
||||
else:
|
||||
logger.debug(
|
||||
"Existing docs or connectors found. Skipping multipass indexing update."
|
||||
)
|
@ -34,6 +34,7 @@ from ee.danswer.server.query_history.api import router as query_history_router
|
||||
from ee.danswer.server.reporting.usage_export_api import router as usage_export_router
|
||||
from ee.danswer.server.saml import router as saml_router
|
||||
from ee.danswer.server.seeding import seed_db
|
||||
from ee.danswer.server.tenants.api import router as tenants_router
|
||||
from ee.danswer.server.token_rate_limits.api import (
|
||||
router as token_rate_limit_settings_router,
|
||||
)
|
||||
@ -79,6 +80,8 @@ def get_application() -> FastAPI:
|
||||
|
||||
# RBAC / group access control
|
||||
include_router_with_global_prefix_prepended(application, user_group_router)
|
||||
# Tenant management
|
||||
include_router_with_global_prefix_prepended(application, tenants_router)
|
||||
# Analytics endpoints
|
||||
include_router_with_global_prefix_prepended(application, analytics_router)
|
||||
include_router_with_global_prefix_prepended(application, query_history_router)
|
||||
|
0
backend/ee/danswer/server/tenants/__init__.py
Normal file
0
backend/ee/danswer/server/tenants/__init__.py
Normal file
0
backend/ee/danswer/server/tenants/access.py
Normal file
0
backend/ee/danswer/server/tenants/access.py
Normal file
46
backend/ee/danswer/server/tenants/api.py
Normal file
46
backend/ee/danswer/server/tenants/api.py
Normal file
@ -0,0 +1,46 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
|
||||
from danswer.auth.users import control_plane_dep
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.setup import setup_danswer
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.server.tenants.models import CreateTenantRequest
|
||||
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
|
||||
from ee.danswer.server.tenants.provisioning import run_alembic_migrations
|
||||
|
||||
logger = setup_logger()
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.post("/create")
|
||||
def create_tenant(
|
||||
create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep)
|
||||
) -> dict[str, str]:
|
||||
try:
|
||||
tenant_id = create_tenant_request.tenant_id
|
||||
|
||||
if not MULTI_TENANT:
|
||||
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
|
||||
|
||||
if not ensure_schema_exists(tenant_id):
|
||||
logger.info(f"Created schema for tenant {tenant_id}")
|
||||
else:
|
||||
logger.info(f"Schema already exists for tenant {tenant_id}")
|
||||
|
||||
run_alembic_migrations(tenant_id)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
setup_danswer(db_session)
|
||||
|
||||
logger.info(f"Tenant {tenant_id} created successfully")
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Tenant {tenant_id} created successfully",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to create tenant {tenant_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create tenant: {str(e)}"
|
||||
)
|
6
backend/ee/danswer/server/tenants/models.py
Normal file
6
backend/ee/danswer/server/tenants/models.py
Normal file
@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CreateTenantRequest(BaseModel):
|
||||
tenant_id: str
|
||||
initial_admin_email: str
|
63
backend/ee/danswer/server/tenants/provisioning.py
Normal file
63
backend/ee/danswer/server/tenants/provisioning.py
Normal file
@ -0,0 +1,63 @@
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.schema import CreateSchema
|
||||
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from danswer.db.engine import build_connection_string
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def run_alembic_migrations(schema_name: str) -> None:
|
||||
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
|
||||
|
||||
try:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", ".."))
|
||||
alembic_ini_path = os.path.join(root_dir, "alembic.ini")
|
||||
|
||||
# Configure Alembic
|
||||
alembic_cfg = Config(alembic_ini_path)
|
||||
alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string())
|
||||
alembic_cfg.set_main_option(
|
||||
"script_location", os.path.join(root_dir, "alembic")
|
||||
)
|
||||
|
||||
# Mimic command-line options by adding 'cmd_opts' to the config
|
||||
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
|
||||
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore
|
||||
|
||||
# Run migrations programmatically
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
|
||||
# Run migrations programmatically
|
||||
logger.info(
|
||||
f"Alembic migrations completed successfully for schema: {schema_name}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def ensure_schema_exists(tenant_id: str) -> bool:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with db_session.begin():
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"
|
||||
),
|
||||
{"schema_name": tenant_id},
|
||||
)
|
||||
schema_exists = result.scalar() is not None
|
||||
if not schema_exists:
|
||||
stmt = CreateSchema(tenant_id)
|
||||
db_session.execute(stmt)
|
||||
return True
|
||||
return False
|
@ -18,8 +18,8 @@ from danswer.db.swap_index import check_index_swap
|
||||
from danswer.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
|
||||
from danswer.document_index.vespa.index import VespaIndex
|
||||
from danswer.indexing.models import IndexingSetting
|
||||
from danswer.main import setup_postgres
|
||||
from danswer.main import setup_vespa
|
||||
from danswer.setup import setup_postgres
|
||||
from danswer.setup import setup_vespa
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
@ -14,6 +14,7 @@ import {
|
||||
ValidQuestionResponse,
|
||||
Relevance,
|
||||
SearchDanswerDocument,
|
||||
SourceMetadata,
|
||||
} from "@/lib/search/interfaces";
|
||||
import { searchRequestStreamed } from "@/lib/search/streamingQa";
|
||||
import { CancellationToken, cancellable } from "@/lib/search/cancellable";
|
||||
@ -40,6 +41,9 @@ import { ApiKeyModal } from "../llm/ApiKeyModal";
|
||||
import { useSearchContext } from "../context/SearchContext";
|
||||
import { useUser } from "../user/UserProvider";
|
||||
import UnconfiguredProviderText from "../chat_search/UnconfiguredProviderText";
|
||||
import { DateRangePickerValue } from "@tremor/react";
|
||||
import { Tag } from "@/lib/types";
|
||||
import { isEqual } from "lodash";
|
||||
|
||||
export type searchState =
|
||||
| "input"
|
||||
@ -370,8 +374,36 @@ export const SearchSection = ({
|
||||
setSearchAnswerExpanded(false);
|
||||
};
|
||||
|
||||
const [previousSearch, setPreviousSearch] = useState<string>("");
|
||||
interface SearchDetails {
|
||||
query: string;
|
||||
sources: SourceMetadata[];
|
||||
agentic: boolean;
|
||||
documentSets: string[];
|
||||
timeRange: DateRangePickerValue | null;
|
||||
tags: Tag[];
|
||||
persona: Persona;
|
||||
}
|
||||
|
||||
const [previousSearch, setPreviousSearch] = useState<null | SearchDetails>(
|
||||
null
|
||||
);
|
||||
const [agenticResults, setAgenticResults] = useState<boolean | null>(null);
|
||||
const currentSearch = (overrideMessage?: string): SearchDetails => {
|
||||
return {
|
||||
query: overrideMessage || query,
|
||||
sources: filterManager.selectedSources,
|
||||
agentic: agentic!,
|
||||
documentSets: filterManager.selectedDocumentSets,
|
||||
timeRange: filterManager.timeRange,
|
||||
tags: filterManager.selectedTags,
|
||||
persona: assistants.find(
|
||||
(assistant) => assistant.id === selectedPersona
|
||||
) as Persona,
|
||||
};
|
||||
};
|
||||
const isSearchChanged = () => {
|
||||
return !isEqual(currentSearch(), previousSearch);
|
||||
};
|
||||
|
||||
let lastSearchCancellationToken = useRef<CancellationToken | null>(null);
|
||||
const onSearch = async ({
|
||||
@ -394,7 +426,9 @@ export const SearchSection = ({
|
||||
|
||||
setIsFetching(true);
|
||||
setSearchResponse(initialSearchResponse);
|
||||
setPreviousSearch(overrideMessage || query);
|
||||
|
||||
setPreviousSearch(currentSearch(overrideMessage));
|
||||
|
||||
const searchFnArgs = {
|
||||
query: overrideMessage || query,
|
||||
sources: filterManager.selectedSources,
|
||||
@ -761,7 +795,7 @@ export const SearchSection = ({
|
||||
/>
|
||||
|
||||
<FullSearchBar
|
||||
disabled={previousSearch === query}
|
||||
disabled={!isSearchChanged()}
|
||||
toggleAgentic={
|
||||
disabledAgentic ? undefined : toggleAgentic
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user