Tenant provisioning in the dataplane (#2694)

* add tenant provisioning to data plane

* minor typing update

* ensure tenant router included

* proper auth check

* update disabling logic

* validated basic provisioning

* use new kv store
pablodanswer 2024-10-05 21:08:35 -07:00 committed by GitHub
parent e00f4678df
commit 0da736bed9
16 changed files with 615 additions and 423 deletions

View File

@@ -8,6 +8,7 @@ from email.mime.text import MIMEText
from typing import Optional
from typing import Tuple
import jwt
from email_validator import EmailNotValidError
from email_validator import validate_email
from fastapi import APIRouter
@@ -37,8 +38,10 @@ from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DATA_PLANE_SECRET
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import EXPECTED_API_KEY
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import SMTP_PASS
@@ -504,3 +507,28 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
def get_default_admin_user_emails_() -> list[str]:
# No default seeding available for Danswer MIT
return []
async def control_plane_dep(request: Request) -> None:
api_key = request.headers.get("X-API-KEY")
if api_key != EXPECTED_API_KEY:
logger.warning("Invalid API key")
raise HTTPException(status_code=401, detail="Invalid API key")
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
logger.warning("Invalid authorization header")
raise HTTPException(status_code=401, detail="Invalid authorization header")
token = auth_header.split(" ")[1]
try:
payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=["HS256"])
if payload.get("scope") != "tenant:create":
logger.warning("Insufficient permissions")
raise HTTPException(status_code=403, detail="Insufficient permissions")
except jwt.ExpiredSignatureError:
logger.warning("Token has expired")
raise HTTPException(status_code=401, detail="Token has expired")
except jwt.InvalidTokenError:
logger.warning("Invalid token")
raise HTTPException(status_code=401, detail="Invalid token")
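For context, the control plane must present both credentials this dependency checks: the shared API key header and a short-lived HS256 JWT scoped to tenant creation. A minimal sketch of the issuing side (a hypothetical helper; the real issuer lives in the control plane and is not part of this diff):

import time

import jwt  # PyJWT, the same library used to decode above

def mint_provisioning_token(data_plane_secret: str) -> str:
    # The scope must be exactly "tenant:create", otherwise control_plane_dep returns 403.
    payload = {"scope": "tenant:create", "exp": int(time.time()) + 300}
    return jwt.encode(payload, data_plane_secret, algorithm="HS256")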

View File

@@ -6,7 +6,6 @@ from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Persona
@@ -18,30 +17,32 @@ from danswer.db.persona import upsert_prompt
from danswer.search.enums import RecencyBiasSetting
def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
def load_prompts_from_yaml(
db_session: Session, prompts_yaml: str = PROMPTS_YAML
) -> None:
with open(prompts_yaml, "r") as file:
data = yaml.safe_load(file)
all_prompts = data.get("prompts", [])
with Session(get_sqlalchemy_engine()) as db_session:
for prompt in all_prompts:
upsert_prompt(
user=None,
prompt_id=prompt.get("id"),
name=prompt["name"],
description=prompt["description"].strip(),
system_prompt=prompt["system"].strip(),
task_prompt=prompt["task"].strip(),
include_citations=prompt["include_citations"],
datetime_aware=prompt.get("datetime_aware", True),
default_prompt=True,
personas=None,
db_session=db_session,
commit=True,
)
for prompt in all_prompts:
upsert_prompt(
user=None,
prompt_id=prompt.get("id"),
name=prompt["name"],
description=prompt["description"].strip(),
system_prompt=prompt["system"].strip(),
task_prompt=prompt["task"].strip(),
include_citations=prompt["include_citations"],
datetime_aware=prompt.get("datetime_aware", True),
default_prompt=True,
personas=None,
db_session=db_session,
commit=True,
)
def load_personas_from_yaml(
db_session: Session,
personas_yaml: str = PERSONAS_YAML,
default_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
) -> None:
@@ -49,117 +50,117 @@ def load_personas_from_yaml(
data = yaml.safe_load(file)
all_personas = data.get("personas", [])
with Session(get_sqlalchemy_engine()) as db_session:
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
]
# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
doc_set_ids: list[int] | None = None
if doc_sets:
doc_set_ids = [doc_set.id for doc_set in doc_sets]
else:
doc_set_ids = None
prompt_ids: list[int] | None = None
prompt_set_names = persona["prompts"]
if prompt_set_names:
prompts: list[PromptDBModel | None] = [
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
]
if any([prompt is None for prompt in prompts]):
raise ValueError("Invalid Persona configs, not all prompts exist")
# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
doc_set_ids: list[int] | None = None
if doc_sets:
doc_set_ids = [doc_set.id for doc_set in doc_sets]
else:
doc_set_ids = None
if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
prompt_ids: list[int] | None = None
prompt_set_names = persona["prompts"]
if prompt_set_names:
prompts: list[PromptDBModel | None] = [
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
]
if any([prompt is None for prompt in prompts]):
raise ValueError("Invalid Persona configs, not all prompts exist")
if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
p_id = persona.get("id")
tool_ids = []
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)
.filter(ToolDBModel.name == "ImageGenerationTool")
.first()
)
if image_gen_tool:
tool_ids.append(image_gen_tool.id)
llm_model_provider_override = persona.get("llm_model_provider_override")
llm_model_version_override = persona.get("llm_model_version_override")
# Set specific overrides for image generation persona
if persona.get("image_generation"):
llm_model_version_override = "gpt-4o"
existing_persona = (
db_session.query(Persona)
.filter(Persona.name == persona["name"])
p_id = persona.get("id")
tool_ids = []
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)
.filter(ToolDBModel.name == "ImageGenerationTool")
.first()
)
if image_gen_tool:
tool_ids.append(image_gen_tool.id)
upsert_persona(
user=None,
persona_id=(-1 * p_id) if p_id is not None else None,
name=persona["name"],
description=persona["description"],
num_chunks=persona.get("num_chunks")
if persona.get("num_chunks") is not None
else default_chunks,
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
icon_shape=persona.get("icon_shape"),
icon_color=persona.get("icon_color"),
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
tool_ids=tool_ids,
builtin_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority"),
is_visible=existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible"),
db_session=db_session,
)
llm_model_provider_override = persona.get("llm_model_provider_override")
llm_model_version_override = persona.get("llm_model_version_override")
# Set specific overrides for image generation persona
if persona.get("image_generation"):
llm_model_version_override = "gpt-4o"
existing_persona = (
db_session.query(Persona).filter(Persona.name == persona["name"]).first()
)
upsert_persona(
user=None,
persona_id=(-1 * p_id) if p_id is not None else None,
name=persona["name"],
description=persona["description"],
num_chunks=persona.get("num_chunks")
if persona.get("num_chunks") is not None
else default_chunks,
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
icon_shape=persona.get("icon_shape"),
icon_color=persona.get("icon_color"),
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
tool_ids=tool_ids,
builtin_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority"),
is_visible=existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible"),
db_session=db_session,
)
def load_input_prompts_from_yaml(input_prompts_yaml: str = INPUT_PROMPT_YAML) -> None:
def load_input_prompts_from_yaml(
db_session: Session, input_prompts_yaml: str = INPUT_PROMPT_YAML
) -> None:
with open(input_prompts_yaml, "r") as file:
data = yaml.safe_load(file)
all_input_prompts = data.get("input_prompts", [])
with Session(get_sqlalchemy_engine()) as db_session:
for input_prompt in all_input_prompts:
# If these prompts are deleted (which is a hard delete in the DB), on server startup
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
insert_input_prompt_if_not_exists(
user=None,
input_prompt_id=input_prompt.get("id"),
prompt=input_prompt["prompt"],
content=input_prompt["content"],
is_public=input_prompt["is_public"],
active=input_prompt.get("active", True),
db_session=db_session,
commit=True,
)
for input_prompt in all_input_prompts:
# If these prompts are deleted (which is a hard delete in the DB), on server startup
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
insert_input_prompt_if_not_exists(
user=None,
input_prompt_id=input_prompt.get("id"),
prompt=input_prompt["prompt"],
content=input_prompt["content"],
is_public=input_prompt["is_public"],
active=input_prompt.get("active", True),
db_session=db_session,
commit=True,
)
def load_chat_yamls(
db_session: Session,
prompt_yaml: str = PROMPTS_YAML,
personas_yaml: str = PERSONAS_YAML,
input_prompts_yaml: str = INPUT_PROMPT_YAML,
) -> None:
load_prompts_from_yaml(prompt_yaml)
load_personas_from_yaml(personas_yaml)
load_input_prompts_from_yaml(input_prompts_yaml)
load_prompts_from_yaml(db_session, prompt_yaml)
load_personas_from_yaml(db_session, personas_yaml)
load_input_prompts_from_yaml(db_session, input_prompts_yaml)
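Callers now own the session lifecycle instead of each loader opening its own connection from the engine. A sketch of the new call pattern (assuming a plain engine-backed session; setup code elsewhere in this commit passes its existing session instead):

from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine

with Session(get_sqlalchemy_engine()) as db_session:
    load_chat_yamls(db_session)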

View File

@@ -416,3 +416,7 @@ ENTERPRISE_EDITION_ENABLED = (
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "")
DATA_PLANE_SECRET = os.environ.get("DATA_PLANE_SECRET", "")
EXPECTED_API_KEY = os.environ.get("EXPECTED_API_KEY", "")

View File

@@ -117,7 +117,7 @@ def get_db_current_time(db_session: Session) -> datetime:
# Regular expression to validate schema names to prevent SQL injection
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
def is_valid_schema_name(name: str) -> bool:
@@ -281,6 +281,7 @@ def get_session_with_tenant(tenant_id: str | None = None) -> Session:
tenant_id = current_tenant_id.get()
if not is_valid_schema_name(tenant_id):
logger.error(f"Invalid tenant ID: {tenant_id}")
raise Exception("Invalid tenant ID")
engine = SqlEngine.get_engine()
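The relaxed SCHEMA_NAME_REGEX accepts tenant IDs with leading digits and hyphens, which the old identifier-style pattern rejected, while still excluding quotes, spaces, and semicolons that could escape a schema name. A quick illustration with hypothetical tenant IDs:

import re

SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")

assert SCHEMA_NAME_REGEX.match("tenant-42")          # hyphens now allowed
assert SCHEMA_NAME_REGEX.match("7f3a9c")             # leading digit now allowed
assert not SCHEMA_NAME_REGEX.match('bad"; DROP --')  # injection characters still rejected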

View File

@@ -5,7 +5,7 @@ from typing import cast
from sqlalchemy.orm import Session
from danswer.db.engine import get_session_factory
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import KVStore
from danswer.key_value_store.interface import JSON_ro
from danswer.key_value_store.interface import KeyValueStore
@@ -26,12 +26,9 @@ class PgRedisKVStore(KeyValueStore):
@contextmanager
def get_session(self) -> Iterator[Session]:
factory = get_session_factory()
session: Session = factory()
try:
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
yield session
finally:
session.close()
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
# Not encrypted in Redis, but encrypted in Postgres
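The rewritten get_session relies on Session's context-manager protocol, which closes the session on exit just as the old finally block did. Passing expire_on_commit=False presumably keeps loaded attributes readable after a commit, so values fetched from the KVStore table stay usable once the session closes. An equivalent standalone sketch:

from collections.abc import Iterator
from contextlib import contextmanager

from sqlalchemy.orm import Session

@contextmanager
def get_session(engine) -> Iterator[Session]:
    # Session.__exit__ closes the session, matching the old try/finally behavior
    with Session(engine, expire_on_commit=False) as session:
        yield session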

View File

@@ -1,4 +1,3 @@
import time
import traceback
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
@@ -23,13 +22,11 @@ from danswer.auth.schemas import UserRead
from danswer.auth.schemas import UserUpdate
from danswer.auth.users import auth_backend
from danswer.auth.users import fastapi_users
from danswer.chat.load_yamls import load_chat_yamls
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.configs.app_configs import APP_HOST
from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import LOG_ENDPOINT_LATENCY
from danswer.configs.app_configs import OAUTH_CLIENT_ID
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
@@ -38,42 +35,9 @@ from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import AuthType
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.configs.constants import KV_SEARCH_SETTINGS
from danswer.configs.constants import POSTGRES_WEB_APP_NAME
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.db.connector import check_connectors_exist
from danswer.db.connector import create_initial_default_connector
from danswer.db.connector_credential_pair import associate_default_cc_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
from danswer.db.credentials import create_initial_public_credential
from danswer.db.document import check_docs_exist
from danswer.db.engine import SqlEngine
from danswer.db.engine import warm_up_connections
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import expire_index_attempts
from danswer.db.llm import fetch_default_provider
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_llm_provider
from danswer.db.persona import delete_old_default_personas
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_current_search_settings
from danswer.db.search_settings import update_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.indexing.models import IndexingSetting
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
from danswer.search.models import SavedSearchSettings
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.auth_check import check_router_auth
from danswer.server.danswer_api.ingestion import router as danswer_api_router
from danswer.server.documents.cc_pair import router as cc_pair_router
@@ -99,7 +63,6 @@ from danswer.server.manage.embedding.api import basic_router as embedding_router
from danswer.server.manage.get_state import router as state_router
from danswer.server.manage.llm.api import admin_router as llm_admin_router
from danswer.server.manage.llm.api import basic_router as llm_router
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.server.manage.search_settings import router as search_settings_router
from danswer.server.manage.slack_bot import router as slack_bot_management_router
from danswer.server.manage.users import router as user_router
@@ -111,15 +74,10 @@ from danswer.server.query_and_chat.query_backend import (
from danswer.server.query_and_chat.query_backend import basic_router as query_router
from danswer.server.settings.api import admin_router as settings_admin_router
from danswer.server.settings.api import basic_router as settings_router
from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings
from danswer.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)
from danswer.tools.built_in_tools import auto_add_search_tool_to_personas
from danswer.tools.built_in_tools import load_builtin_tools
from danswer.tools.built_in_tools import refresh_built_in_tools_cache
from danswer.utils.gpu_utils import gpu_status_request
from danswer.setup import setup_danswer
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import get_or_generate_uuid
from danswer.utils.telemetry import optional_telemetry
@@ -128,8 +86,6 @@ from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import CORS_ALLOWED_ORIGIN
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
@@ -182,176 +138,6 @@ def include_router_with_global_prefix_prepended(
application.include_router(router, **final_kwargs)
def setup_postgres(db_session: Session) -> None:
logger.notice("Verifying default connector/credential exist.")
create_initial_public_credential(db_session)
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
logger.notice("Loading default Prompts and Personas")
delete_old_default_personas(db_session)
load_chat_yamls()
logger.notice("Loading built-in tools")
load_builtin_tools(db_session)
refresh_built_in_tools_cache(db_session)
auto_add_search_tool_to_personas(db_session)
if GEN_AI_API_KEY and fetch_default_provider(db_session) is None:
# Only for dev flows
logger.notice("Setting up default OpenAI LLM for dev.")
llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
fast_model = FAST_GEN_AI_MODEL_VERSION or "gpt-4o-mini"
model_req = LLMProviderUpsertRequest(
name="DevEnvPresetOpenAI",
provider="openai",
api_key=GEN_AI_API_KEY,
api_base=None,
api_version=None,
custom_config=None,
default_model_name=llm_model,
fast_default_model_name=fast_model,
is_public=True,
groups=[],
display_model_names=[llm_model, fast_model],
model_names=[llm_model, fast_model],
)
new_llm_provider = upsert_llm_provider(
llm_provider=model_req, db_session=db_session
)
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
def update_default_multipass_indexing(db_session: Session) -> None:
docs_exist = check_docs_exist(db_session)
connectors_exist = check_connectors_exist(db_session)
logger.debug(f"Docs exist: {docs_exist}, Connectors exist: {connectors_exist}")
if not docs_exist and not connectors_exist:
logger.info(
"No existing docs or connectors found. Checking GPU availability for multipass indexing."
)
gpu_available = gpu_status_request()
logger.info(f"GPU available: {gpu_available}")
current_settings = get_current_search_settings(db_session)
logger.notice(f"Updating multipass indexing setting to: {gpu_available}")
updated_settings = SavedSearchSettings.from_db_model(current_settings)
# Enable multipass indexing if GPU is available or if using a cloud provider
updated_settings.multipass_indexing = (
gpu_available or current_settings.cloud_provider is not None
)
update_current_search_settings(db_session, updated_settings)
# Update settings with GPU availability
settings = load_settings()
settings.gpu_enabled = gpu_available
store_settings(settings)
logger.notice(f"Updated settings with GPU availability: {gpu_available}")
else:
logger.debug(
"Existing docs or connectors found. Skipping multipass indexing update."
)
def translate_saved_search_settings(db_session: Session) -> None:
kv_store = get_kv_store()
try:
search_settings_dict = kv_store.load(KV_SEARCH_SETTINGS)
if isinstance(search_settings_dict, dict):
# Update current search settings
current_settings = get_current_search_settings(db_session)
# Update non-preserved fields
if current_settings:
current_settings_dict = SavedSearchSettings.from_db_model(
current_settings
).dict()
new_current_settings = SavedSearchSettings(
**{**current_settings_dict, **search_settings_dict}
)
update_current_search_settings(db_session, new_current_settings)
# Update secondary search settings
secondary_settings = get_secondary_search_settings(db_session)
if secondary_settings:
secondary_settings_dict = SavedSearchSettings.from_db_model(
secondary_settings
).dict()
new_secondary_settings = SavedSearchSettings(
**{**secondary_settings_dict, **search_settings_dict}
)
update_secondary_search_settings(
db_session,
new_secondary_settings,
)
# Delete the KV store entry after successful update
kv_store.delete(KV_SEARCH_SETTINGS)
logger.notice("Search settings updated and KV store entry deleted.")
else:
logger.notice("KV store search settings is empty.")
except KvKeyNotFoundError:
logger.notice("No search config found in KV store.")
def mark_reindex_flag(db_session: Session) -> None:
kv_store = get_kv_store()
try:
value = kv_store.load(KV_REINDEX_KEY)
logger.debug(f"Re-indexing flag has value {value}")
return
except KvKeyNotFoundError:
# Only need to update the flag if it hasn't been set
pass
# If their first deployment is after the changes, it will
# enable this when the other changes go in, need to avoid
# this being set to False, then the user indexes things on the old version
docs_exist = check_docs_exist(db_session)
connectors_exist = check_connectors_exist(db_session)
if docs_exist or connectors_exist:
kv_store.store(KV_REINDEX_KEY, True)
else:
kv_store.store(KV_REINDEX_KEY, False)
def setup_vespa(
document_index: DocumentIndex,
index_setting: IndexingSetting,
secondary_index_setting: IndexingSetting | None,
) -> bool:
# Vespa startup is a bit slow, so give it a few seconds
WAIT_SECONDS = 5
VESPA_ATTEMPTS = 5
for x in range(VESPA_ATTEMPTS):
try:
logger.notice(f"Setting up Vespa (attempt {x+1}/{VESPA_ATTEMPTS})...")
document_index.ensure_indices_exist(
index_embedding_dim=index_setting.model_dim,
secondary_index_embedding_dim=secondary_index_setting.model_dim
if secondary_index_setting
else None,
)
logger.notice("Vespa setup complete.")
return True
except Exception:
logger.notice(
f"Vespa setup did not succeed. The Vespa service may not be ready yet. Retrying in {WAIT_SECONDS} seconds."
)
time.sleep(WAIT_SECONDS)
logger.error(
f"Vespa setup did not succeed. Attempt limit reached. ({VESPA_ATTEMPTS})"
)
return False
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
SqlEngine.set_app_name(POSTGRES_WEB_APP_NAME)
@@ -380,89 +166,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
get_or_generate_uuid()
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
search_settings = get_current_search_settings(db_session)
secondary_search_settings = get_secondary_search_settings(db_session)
# Break bad state for thrashing indexes
if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP:
expire_index_attempts(
search_settings_id=search_settings.id, db_session=db_session
)
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(cc_pair, db_session=db_session)
# Expire all old embedding models indexing attempts, technically redundant
cancel_indexing_attempts_past_model(db_session)
logger.notice(f'Using Embedding model: "{search_settings.model_name}"')
if search_settings.query_prefix or search_settings.passage_prefix:
logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"')
logger.notice(
f'Passage embedding prefix: "{search_settings.passage_prefix}"'
)
if search_settings:
if not search_settings.disable_rerank_for_streaming:
logger.notice("Reranking is enabled.")
if search_settings.multilingual_expansion:
logger.notice(
f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
)
if (
search_settings.rerank_model_name
and not search_settings.provider_type
and not search_settings.rerank_provider_type
):
warm_up_cross_encoder(search_settings.rerank_model_name)
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
download_nltk_data()
# setup Postgres with default credential, llm providers, etc.
setup_postgres(db_session)
translate_saved_search_settings(db_session)
# Does the user need to trigger a reindexing to bring the document index
# into a good state, marked in the kv store
mark_reindex_flag(db_session)
# ensure Vespa is setup correctly
logger.notice("Verifying Document Index(s) is/are available.")
document_index = get_default_document_index(
primary_index_name=search_settings.index_name,
secondary_index_name=secondary_search_settings.index_name
if secondary_search_settings
else None,
)
success = setup_vespa(
document_index,
IndexingSetting.from_db_model(search_settings),
IndexingSetting.from_db_model(secondary_search_settings)
if secondary_search_settings
else None,
)
if not success:
raise RuntimeError(
"Could not connect to Vespa within the specified timeout."
)
logger.notice(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
if search_settings.provider_type is None:
warm_up_bi_encoder(
embedding_model=EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
),
)
# update multipass indexing setting based on GPU availability
update_default_multipass_indexing(db_session)
setup_danswer(db_session)
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
yield

View File

@@ -4,6 +4,7 @@ from fastapi import FastAPI
from fastapi.dependencies.models import Dependant
from starlette.routing import BaseRoute
from danswer.auth.users import control_plane_dep
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
@@ -98,6 +99,7 @@ def check_router_auth(
or depends_fn == current_curator_or_admin_user
or depends_fn == api_key_dep
or depends_fn == current_user_with_expired_token
or depends_fn == control_plane_dep
):
found_auth = True
break

backend/danswer/setup.py (new file, 303 lines)
View File

@@ -0,0 +1,303 @@
import time
from sqlalchemy.orm import Session
from danswer.chat.load_yamls import load_chat_yamls
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.configs.constants import KV_SEARCH_SETTINGS
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.db.connector import check_connectors_exist
from danswer.db.connector import create_initial_default_connector
from danswer.db.connector_credential_pair import associate_default_cc_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
from danswer.db.credentials import create_initial_public_credential
from danswer.db.document import check_docs_exist
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import expire_index_attempts
from danswer.db.llm import fetch_default_provider
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_llm_provider
from danswer.db.persona import delete_old_default_personas
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_current_search_settings
from danswer.db.search_settings import update_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.indexing.models import IndexingSetting
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
from danswer.search.models import SavedSearchSettings
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings
from danswer.tools.built_in_tools import auto_add_search_tool_to_personas
from danswer.tools.built_in_tools import load_builtin_tools
from danswer.tools.built_in_tools import refresh_built_in_tools_cache
from danswer.utils.gpu_utils import gpu_status_request
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
def setup_danswer(db_session: Session) -> None:
check_index_swap(db_session=db_session)
search_settings = get_current_search_settings(db_session)
secondary_search_settings = get_secondary_search_settings(db_session)
# Break bad state for thrashing indexes
if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP:
expire_index_attempts(
search_settings_id=search_settings.id, db_session=db_session
)
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(cc_pair, db_session=db_session)
# Expire all old embedding models indexing attempts, technically redundant
cancel_indexing_attempts_past_model(db_session)
logger.notice(f'Using Embedding model: "{search_settings.model_name}"')
if search_settings.query_prefix or search_settings.passage_prefix:
logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"')
logger.notice(f'Passage embedding prefix: "{search_settings.passage_prefix}"')
if search_settings:
if not search_settings.disable_rerank_for_streaming:
logger.notice("Reranking is enabled.")
if search_settings.multilingual_expansion:
logger.notice(
f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
)
if (
search_settings.rerank_model_name
and not search_settings.provider_type
and not search_settings.rerank_provider_type
):
warm_up_cross_encoder(search_settings.rerank_model_name)
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
download_nltk_data()
# setup Postgres with default credential, llm providers, etc.
setup_postgres(db_session)
translate_saved_search_settings(db_session)
# Does the user need to trigger a reindexing to bring the document index
# into a good state, marked in the kv store
mark_reindex_flag(db_session)
# ensure Vespa is setup correctly
logger.notice("Verifying Document Index(s) is/are available.")
document_index = get_default_document_index(
primary_index_name=search_settings.index_name,
secondary_index_name=secondary_search_settings.index_name
if secondary_search_settings
else None,
)
success = setup_vespa(
document_index,
IndexingSetting.from_db_model(search_settings),
IndexingSetting.from_db_model(secondary_search_settings)
if secondary_search_settings
else None,
)
if not success:
raise RuntimeError("Could not connect to Vespa within the specified timeout.")
logger.notice(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
if search_settings.provider_type is None:
warm_up_bi_encoder(
embedding_model=EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
),
)
# update multipass indexing setting based on GPU availability
update_default_multipass_indexing(db_session)
def translate_saved_search_settings(db_session: Session) -> None:
kv_store = get_kv_store()
try:
search_settings_dict = kv_store.load(KV_SEARCH_SETTINGS)
if isinstance(search_settings_dict, dict):
# Update current search settings
current_settings = get_current_search_settings(db_session)
# Update non-preserved fields
if current_settings:
current_settings_dict = SavedSearchSettings.from_db_model(
current_settings
).dict()
new_current_settings = SavedSearchSettings(
**{**current_settings_dict, **search_settings_dict}
)
update_current_search_settings(db_session, new_current_settings)
# Update secondary search settings
secondary_settings = get_secondary_search_settings(db_session)
if secondary_settings:
secondary_settings_dict = SavedSearchSettings.from_db_model(
secondary_settings
).dict()
new_secondary_settings = SavedSearchSettings(
**{**secondary_settings_dict, **search_settings_dict}
)
update_secondary_search_settings(
db_session,
new_secondary_settings,
)
# Delete the KV store entry after successful update
kv_store.delete(KV_SEARCH_SETTINGS)
logger.notice("Search settings updated and KV store entry deleted.")
else:
logger.notice("KV store search settings is empty.")
except KvKeyNotFoundError:
logger.notice("No search config found in KV store.")
def mark_reindex_flag(db_session: Session) -> None:
kv_store = get_kv_store()
try:
value = kv_store.load(KV_REINDEX_KEY)
logger.debug(f"Re-indexing flag has value {value}")
return
except KvKeyNotFoundError:
# Only need to update the flag if it hasn't been set
pass
# If their first deployment is after the changes, it will
# enable this when the other changes go in, need to avoid
# this being set to False, then the user indexes things on the old version
docs_exist = check_docs_exist(db_session)
connectors_exist = check_connectors_exist(db_session)
if docs_exist or connectors_exist:
kv_store.store(KV_REINDEX_KEY, True)
else:
kv_store.store(KV_REINDEX_KEY, False)
def setup_vespa(
document_index: DocumentIndex,
index_setting: IndexingSetting,
secondary_index_setting: IndexingSetting | None,
) -> bool:
# Vespa startup is a bit slow, so give it a few seconds
WAIT_SECONDS = 5
VESPA_ATTEMPTS = 5
for x in range(VESPA_ATTEMPTS):
try:
logger.notice(f"Setting up Vespa (attempt {x+1}/{VESPA_ATTEMPTS})...")
document_index.ensure_indices_exist(
index_embedding_dim=index_setting.model_dim,
secondary_index_embedding_dim=secondary_index_setting.model_dim
if secondary_index_setting
else None,
)
logger.notice("Vespa setup complete.")
return True
except Exception:
logger.notice(
f"Vespa setup did not succeed. The Vespa service may not be ready yet. Retrying in {WAIT_SECONDS} seconds."
)
time.sleep(WAIT_SECONDS)
logger.error(
f"Vespa setup did not succeed. Attempt limit reached. ({VESPA_ATTEMPTS})"
)
return False
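# (With WAIT_SECONDS = 5 and VESPA_ATTEMPTS = 5, the retry loop above gives Vespa
# roughly 25 seconds to become ready before setup_danswer raises a RuntimeError.)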
def setup_postgres(db_session: Session) -> None:
logger.notice("Verifying default connector/credential exist.")
create_initial_public_credential(db_session)
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
logger.notice("Loading default Prompts and Personas")
delete_old_default_personas(db_session)
load_chat_yamls(db_session)
logger.notice("Loading built-in tools")
load_builtin_tools(db_session)
refresh_built_in_tools_cache(db_session)
auto_add_search_tool_to_personas(db_session)
if GEN_AI_API_KEY and fetch_default_provider(db_session) is None:
# Only for dev flows
logger.notice("Setting up default OpenAI LLM for dev.")
llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
fast_model = FAST_GEN_AI_MODEL_VERSION or "gpt-4o-mini"
model_req = LLMProviderUpsertRequest(
name="DevEnvPresetOpenAI",
provider="openai",
api_key=GEN_AI_API_KEY,
api_base=None,
api_version=None,
custom_config=None,
default_model_name=llm_model,
fast_default_model_name=fast_model,
is_public=True,
groups=[],
display_model_names=[llm_model, fast_model],
model_names=[llm_model, fast_model],
)
new_llm_provider = upsert_llm_provider(
llm_provider=model_req, db_session=db_session
)
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
def update_default_multipass_indexing(db_session: Session) -> None:
docs_exist = check_docs_exist(db_session)
connectors_exist = check_connectors_exist(db_session)
logger.debug(f"Docs exist: {docs_exist}, Connectors exist: {connectors_exist}")
if not docs_exist and not connectors_exist:
logger.info(
"No existing docs or connectors found. Checking GPU availability for multipass indexing."
)
gpu_available = gpu_status_request()
logger.info(f"GPU available: {gpu_available}")
current_settings = get_current_search_settings(db_session)
logger.notice(f"Updating multipass indexing setting to: {gpu_available}")
updated_settings = SavedSearchSettings.from_db_model(current_settings)
# Enable multipass indexing if GPU is available or if using a cloud provider
updated_settings.multipass_indexing = (
gpu_available or current_settings.cloud_provider is not None
)
update_current_search_settings(db_session, updated_settings)
# Update settings with GPU availability
settings = load_settings()
settings.gpu_enabled = gpu_available
store_settings(settings)
logger.notice(f"Updated settings with GPU availability: {gpu_available}")
else:
logger.debug(
"Existing docs or connectors found. Skipping multipass indexing update."
)

View File

@@ -34,6 +34,7 @@ from ee.danswer.server.query_history.api import router as query_history_router
from ee.danswer.server.reporting.usage_export_api import router as usage_export_router
from ee.danswer.server.saml import router as saml_router
from ee.danswer.server.seeding import seed_db
from ee.danswer.server.tenants.api import router as tenants_router
from ee.danswer.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)
@@ -79,6 +80,8 @@ def get_application() -> FastAPI:
# RBAC / group access control
include_router_with_global_prefix_prepended(application, user_group_router)
# Tenant management
include_router_with_global_prefix_prepended(application, tenants_router)
# Analytics endpoints
include_router_with_global_prefix_prepended(application, analytics_router)
include_router_with_global_prefix_prepended(application, query_history_router)

View File

@@ -0,0 +1,46 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from danswer.auth.users import control_plane_dep
from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.engine import get_session_with_tenant
from danswer.setup import setup_danswer
from danswer.utils.logger import setup_logger
from ee.danswer.server.tenants.models import CreateTenantRequest
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
from ee.danswer.server.tenants.provisioning import run_alembic_migrations
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/create")
def create_tenant(
create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep)
) -> dict[str, str]:
try:
tenant_id = create_tenant_request.tenant_id
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
        if ensure_schema_exists(tenant_id):
            logger.info(f"Created schema for tenant {tenant_id}")
        else:
            logger.info(f"Schema already exists for tenant {tenant_id}")
run_alembic_migrations(tenant_id)
with get_session_with_tenant(tenant_id) as db_session:
setup_danswer(db_session)
logger.info(f"Tenant {tenant_id} created successfully")
return {
"status": "success",
"message": f"Tenant {tenant_id} created successfully",
}
    except HTTPException:
        # Re-raise intentional errors (e.g. the 403 above) instead of wrapping them in a 500
        raise
    except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Failed to create tenant: {str(e)}"
)
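Putting the pieces together, a control-plane call to this route needs both the API key and the scoped JWT. A hedged sketch (the base URL, secret handling, and the requests dependency are assumptions, not part of this diff, and any global API prefix configured at app startup would prepend the path):

import requests

def provision_tenant(base_url: str, api_key: str, token: str) -> dict[str, str]:
    resp = requests.post(
        f"{base_url}/tenants/create",
        headers={"X-API-KEY": api_key, "Authorization": f"Bearer {token}"},
        json={"tenant_id": "tenant-42", "initial_admin_email": "admin@example.com"},
        timeout=60,
    )
    resp.raise_for_status()  # 401/403 here means control_plane_dep rejected the credentials
    return resp.json()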

View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
class CreateTenantRequest(BaseModel):
tenant_id: str
initial_admin_email: str

View File

@@ -0,0 +1,63 @@
import os
from types import SimpleNamespace
from sqlalchemy import text
from sqlalchemy.orm import Session
from sqlalchemy.schema import CreateSchema
from alembic import command
from alembic.config import Config
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_sqlalchemy_engine
from danswer.utils.logger import setup_logger
logger = setup_logger()
def run_alembic_migrations(schema_name: str) -> None:
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", ".."))
alembic_ini_path = os.path.join(root_dir, "alembic.ini")
# Configure Alembic
alembic_cfg = Config(alembic_ini_path)
alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string())
alembic_cfg.set_main_option(
"script_location", os.path.join(root_dir, "alembic")
)
# Mimic command-line options by adding 'cmd_opts' to the config
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore
# Run migrations programmatically
command.upgrade(alembic_cfg, "head")
logger.info(
f"Alembic migrations completed successfully for schema: {schema_name}"
)
except Exception as e:
logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}")
raise
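The cmd_opts shim mimics passing -x schema=<tenant_id> on the Alembic command line. For that to take effect, the project's alembic/env.py must read the value back; presumably it does something along these lines (a sketch, not code from this diff):

from alembic import context

x_args = context.get_x_argument(as_dictionary=True)
schema_name = x_args.get("schema", "public")  # matches cmd_opts.x = [f"schema={schema_name}"]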
def ensure_schema_exists(tenant_id: str) -> bool:
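    # Returns True if this call created the schema, False if it already existed.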
with Session(get_sqlalchemy_engine()) as db_session:
with db_session.begin():
result = db_session.execute(
text(
"SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"
),
{"schema_name": tenant_id},
)
schema_exists = result.scalar() is not None
if not schema_exists:
stmt = CreateSchema(tenant_id)
db_session.execute(stmt)
return True
return False

View File

@@ -18,8 +18,8 @@ from danswer.db.swap_index import check_index_swap
from danswer.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
from danswer.document_index.vespa.index import VespaIndex
from danswer.indexing.models import IndexingSetting
from danswer.main import setup_postgres
from danswer.main import setup_vespa
from danswer.setup import setup_postgres
from danswer.setup import setup_vespa
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@@ -14,6 +14,7 @@ import {
ValidQuestionResponse,
Relevance,
SearchDanswerDocument,
SourceMetadata,
} from "@/lib/search/interfaces";
import { searchRequestStreamed } from "@/lib/search/streamingQa";
import { CancellationToken, cancellable } from "@/lib/search/cancellable";
@@ -40,6 +41,9 @@ import { ApiKeyModal } from "../llm/ApiKeyModal";
import { useSearchContext } from "../context/SearchContext";
import { useUser } from "../user/UserProvider";
import UnconfiguredProviderText from "../chat_search/UnconfiguredProviderText";
import { DateRangePickerValue } from "@tremor/react";
import { Tag } from "@/lib/types";
import { isEqual } from "lodash";
export type searchState =
| "input"
@@ -370,8 +374,36 @@ export const SearchSection = ({
setSearchAnswerExpanded(false);
};
const [previousSearch, setPreviousSearch] = useState<string>("");
interface SearchDetails {
query: string;
sources: SourceMetadata[];
agentic: boolean;
documentSets: string[];
timeRange: DateRangePickerValue | null;
tags: Tag[];
persona: Persona;
}
const [previousSearch, setPreviousSearch] = useState<null | SearchDetails>(
null
);
const [agenticResults, setAgenticResults] = useState<boolean | null>(null);
const currentSearch = (overrideMessage?: string): SearchDetails => {
return {
query: overrideMessage || query,
sources: filterManager.selectedSources,
agentic: agentic!,
documentSets: filterManager.selectedDocumentSets,
timeRange: filterManager.timeRange,
tags: filterManager.selectedTags,
persona: assistants.find(
(assistant) => assistant.id === selectedPersona
) as Persona,
};
};
const isSearchChanged = () => {
return !isEqual(currentSearch(), previousSearch);
};
let lastSearchCancellationToken = useRef<CancellationToken | null>(null);
const onSearch = async ({
@@ -394,7 +426,9 @@
setIsFetching(true);
setSearchResponse(initialSearchResponse);
setPreviousSearch(overrideMessage || query);
setPreviousSearch(currentSearch(overrideMessage));
const searchFnArgs = {
query: overrideMessage || query,
sources: filterManager.selectedSources,
@@ -761,7 +795,7 @@
/>
<FullSearchBar
disabled={previousSearch === query}
disabled={!isSearchChanged()}
toggleAgentic={
disabledAgentic ? undefined : toggleAgentic
}