mirror of https://github.com/danswer-ai/danswer.git (synced 2025-06-19 12:30:55 +02:00)

Tenant provisioning in the dataplane (#2694)

* add tenant provisioning to data plane
* minor typing update
* ensure tenant router included
* proper auth check
* update disabling logic
* validated basic provisioning
* use new kv store

This commit is contained in:
parent e00f4678df
commit 0da736bed9
@@ -8,6 +8,7 @@ from email.mime.text import MIMEText
 from typing import Optional
 from typing import Tuple

+import jwt
 from email_validator import EmailNotValidError
 from email_validator import validate_email
 from fastapi import APIRouter
@@ -37,8 +38,10 @@ from danswer.auth.schemas import UserCreate
 from danswer.auth.schemas import UserRole
 from danswer.auth.schemas import UserUpdate
 from danswer.configs.app_configs import AUTH_TYPE
+from danswer.configs.app_configs import DATA_PLANE_SECRET
 from danswer.configs.app_configs import DISABLE_AUTH
 from danswer.configs.app_configs import EMAIL_FROM
+from danswer.configs.app_configs import EXPECTED_API_KEY
 from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
 from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
 from danswer.configs.app_configs import SMTP_PASS
@@ -504,3 +507,28 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
 def get_default_admin_user_emails_() -> list[str]:
     # No default seeding available for Danswer MIT
     return []
+
+
+async def control_plane_dep(request: Request) -> None:
+    api_key = request.headers.get("X-API-KEY")
+    if api_key != EXPECTED_API_KEY:
+        logger.warning("Invalid API key")
+        raise HTTPException(status_code=401, detail="Invalid API key")
+
+    auth_header = request.headers.get("Authorization")
+    if not auth_header or not auth_header.startswith("Bearer "):
+        logger.warning("Invalid authorization header")
+        raise HTTPException(status_code=401, detail="Invalid authorization header")
+
+    token = auth_header.split(" ")[1]
+    try:
+        payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=["HS256"])
+        if payload.get("scope") != "tenant:create":
+            logger.warning("Insufficient permissions")
+            raise HTTPException(status_code=403, detail="Insufficient permissions")
+    except jwt.ExpiredSignatureError:
+        logger.warning("Token has expired")
+        raise HTTPException(status_code=401, detail="Token has expired")
+    except jwt.InvalidTokenError:
+        logger.warning("Invalid token")
+        raise HTTPException(status_code=401, detail="Invalid token")
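For illustration, a control plane calling an endpoint guarded by this dependency would need both credentials: the static X-API-KEY value and a short-lived HS256 token, signed with the shared DATA_PLANE_SECRET, whose payload carries the "tenant:create" scope. A minimal client-side sketch using PyJWT and requests follows; the host, route prefix, and payload values are assumptions for illustration, not taken from this diff.

import time

import jwt  # PyJWT
import requests

DATA_PLANE_SECRET = "shared-secret"     # must match the data plane's DATA_PLANE_SECRET
EXPECTED_API_KEY = "control-plane-key"  # must match the data plane's EXPECTED_API_KEY

# Short-lived token carrying the scope control_plane_dep checks for.
token = jwt.encode(
    {"scope": "tenant:create", "exp": int(time.time()) + 300},
    DATA_PLANE_SECRET,
    algorithm="HS256",
)

resp = requests.post(
    "https://data-plane.example.com/tenants/create",  # hypothetical host
    headers={"X-API-KEY": EXPECTED_API_KEY, "Authorization": f"Bearer {token}"},
    json={"tenant_id": "tenant_abc", "initial_admin_email": "admin@example.com"},
)
resp.raise_for_status()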
@@ -6,7 +6,6 @@ from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
 from danswer.configs.chat_configs import PERSONAS_YAML
 from danswer.configs.chat_configs import PROMPTS_YAML
 from danswer.db.document_set import get_or_create_document_set_by_name
-from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.input_prompt import insert_input_prompt_if_not_exists
 from danswer.db.models import DocumentSet as DocumentSetDBModel
 from danswer.db.models import Persona
@@ -18,30 +17,32 @@ from danswer.db.persona import upsert_prompt
 from danswer.search.enums import RecencyBiasSetting


-def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
+def load_prompts_from_yaml(
+    db_session: Session, prompts_yaml: str = PROMPTS_YAML
+) -> None:
     with open(prompts_yaml, "r") as file:
         data = yaml.safe_load(file)

     all_prompts = data.get("prompts", [])
-    with Session(get_sqlalchemy_engine()) as db_session:
-        for prompt in all_prompts:
-            upsert_prompt(
-                user=None,
-                prompt_id=prompt.get("id"),
-                name=prompt["name"],
-                description=prompt["description"].strip(),
-                system_prompt=prompt["system"].strip(),
-                task_prompt=prompt["task"].strip(),
-                include_citations=prompt["include_citations"],
-                datetime_aware=prompt.get("datetime_aware", True),
-                default_prompt=True,
-                personas=None,
-                db_session=db_session,
-                commit=True,
-            )
+    for prompt in all_prompts:
+        upsert_prompt(
+            user=None,
+            prompt_id=prompt.get("id"),
+            name=prompt["name"],
+            description=prompt["description"].strip(),
+            system_prompt=prompt["system"].strip(),
+            task_prompt=prompt["task"].strip(),
+            include_citations=prompt["include_citations"],
+            datetime_aware=prompt.get("datetime_aware", True),
+            default_prompt=True,
+            personas=None,
+            db_session=db_session,
+            commit=True,
+        )


 def load_personas_from_yaml(
+    db_session: Session,
     personas_yaml: str = PERSONAS_YAML,
     default_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
 ) -> None:
@@ -49,117 +50,117 @@ def load_personas_from_yaml(
         data = yaml.safe_load(file)

     all_personas = data.get("personas", [])
-    with Session(get_sqlalchemy_engine()) as db_session:
-        for persona in all_personas:
-            doc_set_names = persona["document_sets"]
-            doc_sets: list[DocumentSetDBModel] = [
-                get_or_create_document_set_by_name(db_session, name)
-                for name in doc_set_names
-            ]
-
-            # Assume if user hasn't set any document sets for the persona, the user may want
-            # to later attach document sets to the persona manually, therefore, don't overwrite/reset
-            # the document sets for the persona
-            doc_set_ids: list[int] | None = None
-            if doc_sets:
-                doc_set_ids = [doc_set.id for doc_set in doc_sets]
-            else:
-                doc_set_ids = None
-
-            prompt_ids: list[int] | None = None
-            prompt_set_names = persona["prompts"]
-            if prompt_set_names:
-                prompts: list[PromptDBModel | None] = [
-                    get_prompt_by_name(prompt_name, user=None, db_session=db_session)
-                    for prompt_name in prompt_set_names
-                ]
-                if any([prompt is None for prompt in prompts]):
-                    raise ValueError("Invalid Persona configs, not all prompts exist")
-
-                if prompts:
-                    prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
-
-            p_id = persona.get("id")
-            tool_ids = []
-            if persona.get("image_generation"):
-                image_gen_tool = (
-                    db_session.query(ToolDBModel)
-                    .filter(ToolDBModel.name == "ImageGenerationTool")
-                    .first()
-                )
-                if image_gen_tool:
-                    tool_ids.append(image_gen_tool.id)
-
-            llm_model_provider_override = persona.get("llm_model_provider_override")
-            llm_model_version_override = persona.get("llm_model_version_override")
-
-            # Set specific overrides for image generation persona
-            if persona.get("image_generation"):
-                llm_model_version_override = "gpt-4o"
-
-            existing_persona = (
-                db_session.query(Persona)
-                .filter(Persona.name == persona["name"])
-                .first()
-            )
-
-            upsert_persona(
-                user=None,
-                persona_id=(-1 * p_id) if p_id is not None else None,
-                name=persona["name"],
-                description=persona["description"],
-                num_chunks=persona.get("num_chunks")
-                if persona.get("num_chunks") is not None
-                else default_chunks,
-                llm_relevance_filter=persona.get("llm_relevance_filter"),
-                starter_messages=persona.get("starter_messages"),
-                llm_filter_extraction=persona.get("llm_filter_extraction"),
-                icon_shape=persona.get("icon_shape"),
-                icon_color=persona.get("icon_color"),
-                llm_model_provider_override=llm_model_provider_override,
-                llm_model_version_override=llm_model_version_override,
-                recency_bias=RecencyBiasSetting(persona["recency_bias"]),
-                prompt_ids=prompt_ids,
-                document_set_ids=doc_set_ids,
-                tool_ids=tool_ids,
-                builtin_persona=True,
-                is_public=True,
-                display_priority=existing_persona.display_priority
-                if existing_persona is not None
-                else persona.get("display_priority"),
-                is_visible=existing_persona.is_visible
-                if existing_persona is not None
-                else persona.get("is_visible"),
-                db_session=db_session,
-            )
+    for persona in all_personas:
+        doc_set_names = persona["document_sets"]
+        doc_sets: list[DocumentSetDBModel] = [
+            get_or_create_document_set_by_name(db_session, name)
+            for name in doc_set_names
+        ]
+
+        # Assume if user hasn't set any document sets for the persona, the user may want
+        # to later attach document sets to the persona manually, therefore, don't overwrite/reset
+        # the document sets for the persona
+        doc_set_ids: list[int] | None = None
+        if doc_sets:
+            doc_set_ids = [doc_set.id for doc_set in doc_sets]
+        else:
+            doc_set_ids = None
+
+        prompt_ids: list[int] | None = None
+        prompt_set_names = persona["prompts"]
+        if prompt_set_names:
+            prompts: list[PromptDBModel | None] = [
+                get_prompt_by_name(prompt_name, user=None, db_session=db_session)
+                for prompt_name in prompt_set_names
+            ]
+            if any([prompt is None for prompt in prompts]):
+                raise ValueError("Invalid Persona configs, not all prompts exist")
+
+            if prompts:
+                prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
+
+        p_id = persona.get("id")
+        tool_ids = []
+        if persona.get("image_generation"):
+            image_gen_tool = (
+                db_session.query(ToolDBModel)
+                .filter(ToolDBModel.name == "ImageGenerationTool")
+                .first()
+            )
+            if image_gen_tool:
+                tool_ids.append(image_gen_tool.id)
+
+        llm_model_provider_override = persona.get("llm_model_provider_override")
+        llm_model_version_override = persona.get("llm_model_version_override")
+
+        # Set specific overrides for image generation persona
+        if persona.get("image_generation"):
+            llm_model_version_override = "gpt-4o"
+
+        existing_persona = (
+            db_session.query(Persona).filter(Persona.name == persona["name"]).first()
+        )
+
+        upsert_persona(
+            user=None,
+            persona_id=(-1 * p_id) if p_id is not None else None,
+            name=persona["name"],
+            description=persona["description"],
+            num_chunks=persona.get("num_chunks")
+            if persona.get("num_chunks") is not None
+            else default_chunks,
+            llm_relevance_filter=persona.get("llm_relevance_filter"),
+            starter_messages=persona.get("starter_messages"),
+            llm_filter_extraction=persona.get("llm_filter_extraction"),
+            icon_shape=persona.get("icon_shape"),
+            icon_color=persona.get("icon_color"),
+            llm_model_provider_override=llm_model_provider_override,
+            llm_model_version_override=llm_model_version_override,
+            recency_bias=RecencyBiasSetting(persona["recency_bias"]),
+            prompt_ids=prompt_ids,
+            document_set_ids=doc_set_ids,
+            tool_ids=tool_ids,
+            builtin_persona=True,
+            is_public=True,
+            display_priority=existing_persona.display_priority
+            if existing_persona is not None
+            else persona.get("display_priority"),
+            is_visible=existing_persona.is_visible
+            if existing_persona is not None
+            else persona.get("is_visible"),
+            db_session=db_session,
+        )


-def load_input_prompts_from_yaml(input_prompts_yaml: str = INPUT_PROMPT_YAML) -> None:
+def load_input_prompts_from_yaml(
+    db_session: Session, input_prompts_yaml: str = INPUT_PROMPT_YAML
+) -> None:
     with open(input_prompts_yaml, "r") as file:
         data = yaml.safe_load(file)

     all_input_prompts = data.get("input_prompts", [])
-    with Session(get_sqlalchemy_engine()) as db_session:
-        for input_prompt in all_input_prompts:
-            # If these prompts are deleted (which is a hard delete in the DB), on server startup
-            # they will be recreated, but the user can always just deactivate them, just a light inconvenience
-            insert_input_prompt_if_not_exists(
-                user=None,
-                input_prompt_id=input_prompt.get("id"),
-                prompt=input_prompt["prompt"],
-                content=input_prompt["content"],
-                is_public=input_prompt["is_public"],
-                active=input_prompt.get("active", True),
-                db_session=db_session,
-                commit=True,
-            )
+    for input_prompt in all_input_prompts:
+        # If these prompts are deleted (which is a hard delete in the DB), on server startup
+        # they will be recreated, but the user can always just deactivate them, just a light inconvenience
+        insert_input_prompt_if_not_exists(
+            user=None,
+            input_prompt_id=input_prompt.get("id"),
+            prompt=input_prompt["prompt"],
+            content=input_prompt["content"],
+            is_public=input_prompt["is_public"],
+            active=input_prompt.get("active", True),
+            db_session=db_session,
+            commit=True,
+        )


 def load_chat_yamls(
+    db_session: Session,
     prompt_yaml: str = PROMPTS_YAML,
     personas_yaml: str = PERSONAS_YAML,
     input_prompts_yaml: str = INPUT_PROMPT_YAML,
 ) -> None:
-    load_prompts_from_yaml(prompt_yaml)
-    load_personas_from_yaml(personas_yaml)
-    load_input_prompts_from_yaml(input_prompts_yaml)
+    load_prompts_from_yaml(db_session, prompt_yaml)
+    load_personas_from_yaml(db_session, personas_yaml)
+    load_input_prompts_from_yaml(db_session, input_prompts_yaml)
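With these signature changes the caller owns the database session instead of each loader opening its own. A sketch of the new call pattern, mirroring how setup.py invokes it later in this diff (the engine helper is the existing danswer.db.engine function):

from sqlalchemy.orm import Session

from danswer.chat.load_yamls import load_chat_yamls
from danswer.db.engine import get_sqlalchemy_engine

with Session(get_sqlalchemy_engine()) as db_session:
    # One session is now shared by the prompt, persona, and input-prompt loaders.
    load_chat_yamls(db_session)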
@@ -416,3 +416,7 @@ ENTERPRISE_EDITION_ENABLED = (

 MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
 SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "")
+
+
+DATA_PLANE_SECRET = os.environ.get("DATA_PLANE_SECRET", "")
+EXPECTED_API_KEY = os.environ.get("EXPECTED_API_KEY", "")
@@ -117,7 +117,7 @@ def get_db_current_time(db_session: Session) -> datetime:


 # Regular expression to validate schema names to prevent SQL injection
-SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
+SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")


 def is_valid_schema_name(name: str) -> bool:
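A quick check of what the relaxed pattern accepts; the behavior follows directly from the regex above. Tenant IDs containing digits or hyphens now validate, which the previous identifier-style pattern rejected, while anything with quotes, dots, or whitespace still fails:

import re

SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")

assert SCHEMA_NAME_REGEX.match("tenant-123")    # hyphens are now allowed
assert SCHEMA_NAME_REGEX.match("123_tenant")    # leading digits are now allowed
assert not SCHEMA_NAME_REGEX.match('bad"; DROP SCHEMA public; --')  # still rejected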
@@ -281,6 +281,7 @@ def get_session_with_tenant(tenant_id: str | None = None) -> Session:
         tenant_id = current_tenant_id.get()

     if not is_valid_schema_name(tenant_id):
+        logger.error(f"Invalid tenant ID: {tenant_id}")
         raise Exception("Invalid tenant ID")

     engine = SqlEngine.get_engine()
@@ -5,7 +5,7 @@ from typing import cast

 from sqlalchemy.orm import Session

-from danswer.db.engine import get_session_factory
+from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.models import KVStore
 from danswer.key_value_store.interface import JSON_ro
 from danswer.key_value_store.interface import KeyValueStore
@@ -26,12 +26,9 @@ class PgRedisKVStore(KeyValueStore):

     @contextmanager
     def get_session(self) -> Iterator[Session]:
-        factory = get_session_factory()
-        session: Session = factory()
-        try:
-            yield session
-        finally:
-            session.close()
+        engine = get_sqlalchemy_engine()
+        with Session(engine, expire_on_commit=False) as session:
+            yield session

     def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
         # Not encrypted in Redis, but encrypted in Postgres
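The dropped try/finally is covered by SQLAlchemy's Session context manager, which closes the session on exit. A standalone sketch of the same pattern, with a placeholder connection URL:

from collections.abc import Iterator
from contextlib import contextmanager

from sqlalchemy import create_engine
from sqlalchemy.orm import Session

engine = create_engine("postgresql://user:pass@localhost:5432/danswer")  # placeholder URL


@contextmanager
def get_session() -> Iterator[Session]:
    # Session.__exit__ closes the session, so no explicit close() is needed.
    with Session(engine, expire_on_commit=False) as session:
        yield session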
@@ -1,4 +1,3 @@
-import time
 import traceback
 from collections.abc import AsyncGenerator
 from contextlib import asynccontextmanager

@@ -23,13 +22,11 @@ from danswer.auth.schemas import UserRead
 from danswer.auth.schemas import UserUpdate
 from danswer.auth.users import auth_backend
 from danswer.auth.users import fastapi_users
-from danswer.chat.load_yamls import load_chat_yamls
 from danswer.configs.app_configs import APP_API_PREFIX
 from danswer.configs.app_configs import APP_HOST
 from danswer.configs.app_configs import APP_PORT
 from danswer.configs.app_configs import AUTH_TYPE
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
-from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
 from danswer.configs.app_configs import LOG_ENDPOINT_LATENCY
 from danswer.configs.app_configs import OAUTH_CLIENT_ID
 from danswer.configs.app_configs import OAUTH_CLIENT_SECRET

@@ -38,42 +35,9 @@ from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
 from danswer.configs.app_configs import USER_AUTH_SECRET
 from danswer.configs.app_configs import WEB_DOMAIN
 from danswer.configs.constants import AuthType
-from danswer.configs.constants import KV_REINDEX_KEY
-from danswer.configs.constants import KV_SEARCH_SETTINGS
 from danswer.configs.constants import POSTGRES_WEB_APP_NAME
-from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
-from danswer.configs.model_configs import GEN_AI_API_KEY
-from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
-from danswer.db.connector import check_connectors_exist
-from danswer.db.connector import create_initial_default_connector
-from danswer.db.connector_credential_pair import associate_default_cc_pair
-from danswer.db.connector_credential_pair import get_connector_credential_pairs
-from danswer.db.connector_credential_pair import resync_cc_pair
-from danswer.db.credentials import create_initial_public_credential
-from danswer.db.document import check_docs_exist
 from danswer.db.engine import SqlEngine
 from danswer.db.engine import warm_up_connections
-from danswer.db.index_attempt import cancel_indexing_attempts_past_model
-from danswer.db.index_attempt import expire_index_attempts
-from danswer.db.llm import fetch_default_provider
-from danswer.db.llm import update_default_provider
-from danswer.db.llm import upsert_llm_provider
-from danswer.db.persona import delete_old_default_personas
-from danswer.db.search_settings import get_current_search_settings
-from danswer.db.search_settings import get_secondary_search_settings
-from danswer.db.search_settings import update_current_search_settings
-from danswer.db.search_settings import update_secondary_search_settings
-from danswer.db.swap_index import check_index_swap
-from danswer.document_index.factory import get_default_document_index
-from danswer.document_index.interfaces import DocumentIndex
-from danswer.indexing.models import IndexingSetting
-from danswer.key_value_store.factory import get_kv_store
-from danswer.key_value_store.interface import KvKeyNotFoundError
-from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
-from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
-from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
-from danswer.search.models import SavedSearchSettings
-from danswer.search.retrieval.search_runner import download_nltk_data
 from danswer.server.auth_check import check_router_auth
 from danswer.server.danswer_api.ingestion import router as danswer_api_router
 from danswer.server.documents.cc_pair import router as cc_pair_router

@@ -99,7 +63,6 @@ from danswer.server.manage.embedding.api import basic_router as embedding_router
 from danswer.server.manage.get_state import router as state_router
 from danswer.server.manage.llm.api import admin_router as llm_admin_router
 from danswer.server.manage.llm.api import basic_router as llm_router
-from danswer.server.manage.llm.models import LLMProviderUpsertRequest
 from danswer.server.manage.search_settings import router as search_settings_router
 from danswer.server.manage.slack_bot import router as slack_bot_management_router
 from danswer.server.manage.users import router as user_router

@@ -111,15 +74,10 @@ from danswer.server.query_and_chat.query_backend import (
 from danswer.server.query_and_chat.query_backend import basic_router as query_router
 from danswer.server.settings.api import admin_router as settings_admin_router
 from danswer.server.settings.api import basic_router as settings_router
-from danswer.server.settings.store import load_settings
-from danswer.server.settings.store import store_settings
 from danswer.server.token_rate_limits.api import (
     router as token_rate_limit_settings_router,
 )
-from danswer.tools.built_in_tools import auto_add_search_tool_to_personas
-from danswer.tools.built_in_tools import load_builtin_tools
-from danswer.tools.built_in_tools import refresh_built_in_tools_cache
-from danswer.utils.gpu_utils import gpu_status_request
+from danswer.setup import setup_danswer
 from danswer.utils.logger import setup_logger
 from danswer.utils.telemetry import get_or_generate_uuid
 from danswer.utils.telemetry import optional_telemetry

@@ -128,8 +86,6 @@ from danswer.utils.variable_functionality import fetch_versioned_implementation
 from danswer.utils.variable_functionality import global_version
 from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
 from shared_configs.configs import CORS_ALLOWED_ORIGIN
-from shared_configs.configs import MODEL_SERVER_HOST
-from shared_configs.configs import MODEL_SERVER_PORT

 logger = setup_logger()

@@ -182,176 +138,6 @@ def include_router_with_global_prefix_prepended(
     application.include_router(router, **final_kwargs)


-def setup_postgres(db_session: Session) -> None:
-    logger.notice("Verifying default connector/credential exist.")
-    create_initial_public_credential(db_session)
-    create_initial_default_connector(db_session)
-    associate_default_cc_pair(db_session)
-
-    logger.notice("Loading default Prompts and Personas")
-    delete_old_default_personas(db_session)
-    load_chat_yamls()
-
-    logger.notice("Loading built-in tools")
-    load_builtin_tools(db_session)
-    refresh_built_in_tools_cache(db_session)
-    auto_add_search_tool_to_personas(db_session)
-
-    if GEN_AI_API_KEY and fetch_default_provider(db_session) is None:
-        # Only for dev flows
-        logger.notice("Setting up default OpenAI LLM for dev.")
-        llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
-        fast_model = FAST_GEN_AI_MODEL_VERSION or "gpt-4o-mini"
-        model_req = LLMProviderUpsertRequest(
-            name="DevEnvPresetOpenAI",
-            provider="openai",
-            api_key=GEN_AI_API_KEY,
-            api_base=None,
-            api_version=None,
-            custom_config=None,
-            default_model_name=llm_model,
-            fast_default_model_name=fast_model,
-            is_public=True,
-            groups=[],
-            display_model_names=[llm_model, fast_model],
-            model_names=[llm_model, fast_model],
-        )
-        new_llm_provider = upsert_llm_provider(
-            llm_provider=model_req, db_session=db_session
-        )
-        update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
-
-
-def update_default_multipass_indexing(db_session: Session) -> None:
-    docs_exist = check_docs_exist(db_session)
-    connectors_exist = check_connectors_exist(db_session)
-    logger.debug(f"Docs exist: {docs_exist}, Connectors exist: {connectors_exist}")
-
-    if not docs_exist and not connectors_exist:
-        logger.info(
-            "No existing docs or connectors found. Checking GPU availability for multipass indexing."
-        )
-        gpu_available = gpu_status_request()
-        logger.info(f"GPU available: {gpu_available}")
-
-        current_settings = get_current_search_settings(db_session)
-
-        logger.notice(f"Updating multipass indexing setting to: {gpu_available}")
-        updated_settings = SavedSearchSettings.from_db_model(current_settings)
-        # Enable multipass indexing if GPU is available or if using a cloud provider
-        updated_settings.multipass_indexing = (
-            gpu_available or current_settings.cloud_provider is not None
-        )
-        update_current_search_settings(db_session, updated_settings)
-
-        # Update settings with GPU availability
-        settings = load_settings()
-        settings.gpu_enabled = gpu_available
-        store_settings(settings)
-        logger.notice(f"Updated settings with GPU availability: {gpu_available}")
-
-    else:
-        logger.debug(
-            "Existing docs or connectors found. Skipping multipass indexing update."
-        )
-
-
-def translate_saved_search_settings(db_session: Session) -> None:
-    kv_store = get_kv_store()
-
-    try:
-        search_settings_dict = kv_store.load(KV_SEARCH_SETTINGS)
-        if isinstance(search_settings_dict, dict):
-            # Update current search settings
-            current_settings = get_current_search_settings(db_session)
-
-            # Update non-preserved fields
-            if current_settings:
-                current_settings_dict = SavedSearchSettings.from_db_model(
-                    current_settings
-                ).dict()
-
-                new_current_settings = SavedSearchSettings(
-                    **{**current_settings_dict, **search_settings_dict}
-                )
-                update_current_search_settings(db_session, new_current_settings)
-
-            # Update secondary search settings
-            secondary_settings = get_secondary_search_settings(db_session)
-            if secondary_settings:
-                secondary_settings_dict = SavedSearchSettings.from_db_model(
-                    secondary_settings
-                ).dict()
-
-                new_secondary_settings = SavedSearchSettings(
-                    **{**secondary_settings_dict, **search_settings_dict}
-                )
-                update_secondary_search_settings(
-                    db_session,
-                    new_secondary_settings,
-                )
-            # Delete the KV store entry after successful update
-            kv_store.delete(KV_SEARCH_SETTINGS)
-            logger.notice("Search settings updated and KV store entry deleted.")
-        else:
-            logger.notice("KV store search settings is empty.")
-    except KvKeyNotFoundError:
-        logger.notice("No search config found in KV store.")
-
-
-def mark_reindex_flag(db_session: Session) -> None:
-    kv_store = get_kv_store()
-    try:
-        value = kv_store.load(KV_REINDEX_KEY)
-        logger.debug(f"Re-indexing flag has value {value}")
-        return
-    except KvKeyNotFoundError:
-        # Only need to update the flag if it hasn't been set
-        pass
-
-    # If their first deployment is after the changes, it will
-    # enable this when the other changes go in, need to avoid
-    # this being set to False, then the user indexes things on the old version
-    docs_exist = check_docs_exist(db_session)
-    connectors_exist = check_connectors_exist(db_session)
-    if docs_exist or connectors_exist:
-        kv_store.store(KV_REINDEX_KEY, True)
-    else:
-        kv_store.store(KV_REINDEX_KEY, False)
-
-
-def setup_vespa(
-    document_index: DocumentIndex,
-    index_setting: IndexingSetting,
-    secondary_index_setting: IndexingSetting | None,
-) -> bool:
-    # Vespa startup is a bit slow, so give it a few seconds
-    WAIT_SECONDS = 5
-    VESPA_ATTEMPTS = 5
-    for x in range(VESPA_ATTEMPTS):
-        try:
-            logger.notice(f"Setting up Vespa (attempt {x+1}/{VESPA_ATTEMPTS})...")
-            document_index.ensure_indices_exist(
-                index_embedding_dim=index_setting.model_dim,
-                secondary_index_embedding_dim=secondary_index_setting.model_dim
-                if secondary_index_setting
-                else None,
-            )
-
-            logger.notice("Vespa setup complete.")
-            return True
-        except Exception:
-            logger.notice(
-                f"Vespa setup did not succeed. The Vespa service may not be ready yet. Retrying in {WAIT_SECONDS} seconds."
-            )
-            time.sleep(WAIT_SECONDS)
-
-    logger.error(
-        f"Vespa setup did not succeed. Attempt limit reached. ({VESPA_ATTEMPTS})"
-    )
-    return False
-
-
 @asynccontextmanager
 async def lifespan(app: FastAPI) -> AsyncGenerator:
     SqlEngine.set_app_name(POSTGRES_WEB_APP_NAME)
@@ -380,89 +166,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
     get_or_generate_uuid()

     with Session(engine) as db_session:
-        check_index_swap(db_session=db_session)
-        search_settings = get_current_search_settings(db_session)
-        secondary_search_settings = get_secondary_search_settings(db_session)
-
-        # Break bad state for thrashing indexes
-        if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP:
-            expire_index_attempts(
-                search_settings_id=search_settings.id, db_session=db_session
-            )
-
-        for cc_pair in get_connector_credential_pairs(db_session):
-            resync_cc_pair(cc_pair, db_session=db_session)
-
-        # Expire all old embedding models indexing attempts, technically redundant
-        cancel_indexing_attempts_past_model(db_session)
-
-        logger.notice(f'Using Embedding model: "{search_settings.model_name}"')
-        if search_settings.query_prefix or search_settings.passage_prefix:
-            logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"')
-            logger.notice(
-                f'Passage embedding prefix: "{search_settings.passage_prefix}"'
-            )
-
-        if search_settings:
-            if not search_settings.disable_rerank_for_streaming:
-                logger.notice("Reranking is enabled.")
-
-            if search_settings.multilingual_expansion:
-                logger.notice(
-                    f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
-                )
-            if (
-                search_settings.rerank_model_name
-                and not search_settings.provider_type
-                and not search_settings.rerank_provider_type
-            ):
-                warm_up_cross_encoder(search_settings.rerank_model_name)
-
-        logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
-        download_nltk_data()
-
-        # setup Postgres with default credential, llm providers, etc.
-        setup_postgres(db_session)
-
-        translate_saved_search_settings(db_session)
-
-        # Does the user need to trigger a reindexing to bring the document index
-        # into a good state, marked in the kv store
-        mark_reindex_flag(db_session)
-
-        # ensure Vespa is setup correctly
-        logger.notice("Verifying Document Index(s) is/are available.")
-        document_index = get_default_document_index(
-            primary_index_name=search_settings.index_name,
-            secondary_index_name=secondary_search_settings.index_name
-            if secondary_search_settings
-            else None,
-        )
-
-        success = setup_vespa(
-            document_index,
-            IndexingSetting.from_db_model(search_settings),
-            IndexingSetting.from_db_model(secondary_search_settings)
-            if secondary_search_settings
-            else None,
-        )
-        if not success:
-            raise RuntimeError(
-                "Could not connect to Vespa within the specified timeout."
-            )
-
-        logger.notice(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
-        if search_settings.provider_type is None:
-            warm_up_bi_encoder(
-                embedding_model=EmbeddingModel.from_db_model(
-                    search_settings=search_settings,
-                    server_host=MODEL_SERVER_HOST,
-                    server_port=MODEL_SERVER_PORT,
-                ),
-            )
-
-        # update multipass indexing setting based on GPU availability
-        update_default_multipass_indexing(db_session)
+        setup_danswer(db_session)

     optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
     yield
@@ -4,6 +4,7 @@ from fastapi import FastAPI
 from fastapi.dependencies.models import Dependant
 from starlette.routing import BaseRoute

+from danswer.auth.users import control_plane_dep
 from danswer.auth.users import current_admin_user
 from danswer.auth.users import current_curator_or_admin_user
 from danswer.auth.users import current_user

@@ -98,6 +99,7 @@ def check_router_auth(
                 or depends_fn == current_curator_or_admin_user
                 or depends_fn == api_key_dep
                 or depends_fn == current_user_with_expired_token
+                or depends_fn == control_plane_dep
             ):
                 found_auth = True
                 break
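For context, check_router_auth is meant to ensure every registered route declares one of the recognized auth dependencies; adding control_plane_dep to that allow list is what lets the tenants router (added below) pass the check. A minimal sketch of a route guarded this way, with the handler body trimmed; the real handler appears in backend/ee/danswer/server/tenants/api.py later in this diff:

from fastapi import APIRouter, Depends, Request

router = APIRouter(prefix="/tenants")


async def control_plane_dep(request: Request) -> None:
    ...  # X-API-KEY and JWT scope checks as defined in danswer.auth.users


@router.post("/create")
def create_tenant(_: None = Depends(control_plane_dep)) -> dict[str, str]:
    # Reachable only when the control-plane credentials validate.
    return {"status": "success"}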
backend/danswer/setup.py (new file, 303 lines)
@@ -0,0 +1,303 @@
import time

from sqlalchemy.orm import Session

from danswer.chat.load_yamls import load_chat_yamls
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.configs.constants import KV_SEARCH_SETTINGS
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.db.connector import check_connectors_exist
from danswer.db.connector import create_initial_default_connector
from danswer.db.connector_credential_pair import associate_default_cc_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
from danswer.db.credentials import create_initial_public_credential
from danswer.db.document import check_docs_exist
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import expire_index_attempts
from danswer.db.llm import fetch_default_provider
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_llm_provider
from danswer.db.persona import delete_old_default_personas
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_current_search_settings
from danswer.db.search_settings import update_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.indexing.models import IndexingSetting
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
from danswer.search.models import SavedSearchSettings
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings
from danswer.tools.built_in_tools import auto_add_search_tool_to_personas
from danswer.tools.built_in_tools import load_builtin_tools
from danswer.tools.built_in_tools import refresh_built_in_tools_cache
from danswer.utils.gpu_utils import gpu_status_request
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT

logger = setup_logger()


def setup_danswer(db_session: Session) -> None:
    check_index_swap(db_session=db_session)
    search_settings = get_current_search_settings(db_session)
    secondary_search_settings = get_secondary_search_settings(db_session)

    # Break bad state for thrashing indexes
    if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP:
        expire_index_attempts(
            search_settings_id=search_settings.id, db_session=db_session
        )

    for cc_pair in get_connector_credential_pairs(db_session):
        resync_cc_pair(cc_pair, db_session=db_session)

    # Expire all old embedding models indexing attempts, technically redundant
    cancel_indexing_attempts_past_model(db_session)

    logger.notice(f'Using Embedding model: "{search_settings.model_name}"')
    if search_settings.query_prefix or search_settings.passage_prefix:
        logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"')
        logger.notice(f'Passage embedding prefix: "{search_settings.passage_prefix}"')

    if search_settings:
        if not search_settings.disable_rerank_for_streaming:
            logger.notice("Reranking is enabled.")

        if search_settings.multilingual_expansion:
            logger.notice(
                f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
            )
        if (
            search_settings.rerank_model_name
            and not search_settings.provider_type
            and not search_settings.rerank_provider_type
        ):
            warm_up_cross_encoder(search_settings.rerank_model_name)

    logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
    download_nltk_data()

    # setup Postgres with default credential, llm providers, etc.
    setup_postgres(db_session)

    translate_saved_search_settings(db_session)

    # Does the user need to trigger a reindexing to bring the document index
    # into a good state, marked in the kv store
    mark_reindex_flag(db_session)

    # ensure Vespa is setup correctly
    logger.notice("Verifying Document Index(s) is/are available.")
    document_index = get_default_document_index(
        primary_index_name=search_settings.index_name,
        secondary_index_name=secondary_search_settings.index_name
        if secondary_search_settings
        else None,
    )

    success = setup_vespa(
        document_index,
        IndexingSetting.from_db_model(search_settings),
        IndexingSetting.from_db_model(secondary_search_settings)
        if secondary_search_settings
        else None,
    )
    if not success:
        raise RuntimeError("Could not connect to Vespa within the specified timeout.")

    logger.notice(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
    if search_settings.provider_type is None:
        warm_up_bi_encoder(
            embedding_model=EmbeddingModel.from_db_model(
                search_settings=search_settings,
                server_host=MODEL_SERVER_HOST,
                server_port=MODEL_SERVER_PORT,
            ),
        )

    # update multipass indexing setting based on GPU availability
    update_default_multipass_indexing(db_session)


def translate_saved_search_settings(db_session: Session) -> None:
    kv_store = get_kv_store()

    try:
        search_settings_dict = kv_store.load(KV_SEARCH_SETTINGS)
        if isinstance(search_settings_dict, dict):
            # Update current search settings
            current_settings = get_current_search_settings(db_session)

            # Update non-preserved fields
            if current_settings:
                current_settings_dict = SavedSearchSettings.from_db_model(
                    current_settings
                ).dict()

                new_current_settings = SavedSearchSettings(
                    **{**current_settings_dict, **search_settings_dict}
                )
                update_current_search_settings(db_session, new_current_settings)

            # Update secondary search settings
            secondary_settings = get_secondary_search_settings(db_session)
            if secondary_settings:
                secondary_settings_dict = SavedSearchSettings.from_db_model(
                    secondary_settings
                ).dict()

                new_secondary_settings = SavedSearchSettings(
                    **{**secondary_settings_dict, **search_settings_dict}
                )
                update_secondary_search_settings(
                    db_session,
                    new_secondary_settings,
                )
            # Delete the KV store entry after successful update
            kv_store.delete(KV_SEARCH_SETTINGS)
            logger.notice("Search settings updated and KV store entry deleted.")
        else:
            logger.notice("KV store search settings is empty.")
    except KvKeyNotFoundError:
        logger.notice("No search config found in KV store.")


def mark_reindex_flag(db_session: Session) -> None:
    kv_store = get_kv_store()
    try:
        value = kv_store.load(KV_REINDEX_KEY)
        logger.debug(f"Re-indexing flag has value {value}")
        return
    except KvKeyNotFoundError:
        # Only need to update the flag if it hasn't been set
        pass

    # If their first deployment is after the changes, it will
    # enable this when the other changes go in, need to avoid
    # this being set to False, then the user indexes things on the old version
    docs_exist = check_docs_exist(db_session)
    connectors_exist = check_connectors_exist(db_session)
    if docs_exist or connectors_exist:
        kv_store.store(KV_REINDEX_KEY, True)
    else:
        kv_store.store(KV_REINDEX_KEY, False)


def setup_vespa(
    document_index: DocumentIndex,
    index_setting: IndexingSetting,
    secondary_index_setting: IndexingSetting | None,
) -> bool:
    # Vespa startup is a bit slow, so give it a few seconds
    WAIT_SECONDS = 5
    VESPA_ATTEMPTS = 5
    for x in range(VESPA_ATTEMPTS):
        try:
            logger.notice(f"Setting up Vespa (attempt {x+1}/{VESPA_ATTEMPTS})...")
            document_index.ensure_indices_exist(
                index_embedding_dim=index_setting.model_dim,
                secondary_index_embedding_dim=secondary_index_setting.model_dim
                if secondary_index_setting
                else None,
            )

            logger.notice("Vespa setup complete.")
            return True
        except Exception:
            logger.notice(
                f"Vespa setup did not succeed. The Vespa service may not be ready yet. Retrying in {WAIT_SECONDS} seconds."
            )
            time.sleep(WAIT_SECONDS)

    logger.error(
        f"Vespa setup did not succeed. Attempt limit reached. ({VESPA_ATTEMPTS})"
    )
    return False


def setup_postgres(db_session: Session) -> None:
    logger.notice("Verifying default connector/credential exist.")
    create_initial_public_credential(db_session)
    create_initial_default_connector(db_session)
    associate_default_cc_pair(db_session)

    logger.notice("Loading default Prompts and Personas")
    delete_old_default_personas(db_session)
    load_chat_yamls(db_session)

    logger.notice("Loading built-in tools")
    load_builtin_tools(db_session)
    refresh_built_in_tools_cache(db_session)
    auto_add_search_tool_to_personas(db_session)

    if GEN_AI_API_KEY and fetch_default_provider(db_session) is None:
        # Only for dev flows
        logger.notice("Setting up default OpenAI LLM for dev.")
        llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
        fast_model = FAST_GEN_AI_MODEL_VERSION or "gpt-4o-mini"
        model_req = LLMProviderUpsertRequest(
            name="DevEnvPresetOpenAI",
            provider="openai",
            api_key=GEN_AI_API_KEY,
            api_base=None,
            api_version=None,
            custom_config=None,
            default_model_name=llm_model,
            fast_default_model_name=fast_model,
            is_public=True,
            groups=[],
            display_model_names=[llm_model, fast_model],
            model_names=[llm_model, fast_model],
        )
        new_llm_provider = upsert_llm_provider(
            llm_provider=model_req, db_session=db_session
        )
        update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)


def update_default_multipass_indexing(db_session: Session) -> None:
    docs_exist = check_docs_exist(db_session)
    connectors_exist = check_connectors_exist(db_session)
    logger.debug(f"Docs exist: {docs_exist}, Connectors exist: {connectors_exist}")

    if not docs_exist and not connectors_exist:
        logger.info(
            "No existing docs or connectors found. Checking GPU availability for multipass indexing."
        )
        gpu_available = gpu_status_request()
        logger.info(f"GPU available: {gpu_available}")

        current_settings = get_current_search_settings(db_session)

        logger.notice(f"Updating multipass indexing setting to: {gpu_available}")
        updated_settings = SavedSearchSettings.from_db_model(current_settings)
        # Enable multipass indexing if GPU is available or if using a cloud provider
        updated_settings.multipass_indexing = (
            gpu_available or current_settings.cloud_provider is not None
        )
        update_current_search_settings(db_session, updated_settings)

        # Update settings with GPU availability
        settings = load_settings()
        settings.gpu_enabled = gpu_available
        store_settings(settings)
        logger.notice(f"Updated settings with GPU availability: {gpu_available}")

    else:
        logger.debug(
            "Existing docs or connectors found. Skipping multipass indexing update."
        )
@@ -34,6 +34,7 @@ from ee.danswer.server.query_history.api import router as query_history_router
 from ee.danswer.server.reporting.usage_export_api import router as usage_export_router
 from ee.danswer.server.saml import router as saml_router
 from ee.danswer.server.seeding import seed_db
+from ee.danswer.server.tenants.api import router as tenants_router
 from ee.danswer.server.token_rate_limits.api import (
     router as token_rate_limit_settings_router,
 )

@@ -79,6 +80,8 @@ def get_application() -> FastAPI:

     # RBAC / group access control
     include_router_with_global_prefix_prepended(application, user_group_router)
+    # Tenant management
+    include_router_with_global_prefix_prepended(application, tenants_router)
     # Analytics endpoints
     include_router_with_global_prefix_prepended(application, analytics_router)
     include_router_with_global_prefix_prepended(application, query_history_router)
backend/ee/danswer/server/tenants/__init__.py (new file, empty)
backend/ee/danswer/server/tenants/access.py (new file, empty)
backend/ee/danswer/server/tenants/api.py (new file, 46 lines)
@@ -0,0 +1,46 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException

from danswer.auth.users import control_plane_dep
from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.engine import get_session_with_tenant
from danswer.setup import setup_danswer
from danswer.utils.logger import setup_logger
from ee.danswer.server.tenants.models import CreateTenantRequest
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
from ee.danswer.server.tenants.provisioning import run_alembic_migrations

logger = setup_logger()
router = APIRouter(prefix="/tenants")


@router.post("/create")
def create_tenant(
    create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep)
) -> dict[str, str]:
    try:
        tenant_id = create_tenant_request.tenant_id

        if not MULTI_TENANT:
            raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")

        if not ensure_schema_exists(tenant_id):
            logger.info(f"Created schema for tenant {tenant_id}")
        else:
            logger.info(f"Schema already exists for tenant {tenant_id}")

        run_alembic_migrations(tenant_id)
        with get_session_with_tenant(tenant_id) as db_session:
            setup_danswer(db_session)

        logger.info(f"Tenant {tenant_id} created successfully")
        return {
            "status": "success",
            "message": f"Tenant {tenant_id} created successfully",
        }
    except Exception as e:
        logger.exception(f"Failed to create tenant {tenant_id}: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to create tenant: {str(e)}"
        )
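For reference, a minimal sketch of how a control-plane client might call this new endpoint. The base URL, prefix, and timeout below are assumptions rather than part of the commit, and whatever credentials control_plane_dep expects must be supplied in the headers for the request to be accepted.

import requests  # hypothetical client-side sketch, not part of the diff

# Placeholder: populate with whatever credentials control_plane_dep requires.
control_plane_headers: dict[str, str] = {}

resp = requests.post(
    "http://localhost:8080/api/tenants/create",  # assumed base URL and global prefix
    headers=control_plane_headers,
    json={"tenant_id": "tenant_abc123", "initial_admin_email": "admin@example.com"},
    timeout=60,
)
resp.raise_for_status()
print(resp.json())  # on success: {"status": "success", "message": "Tenant tenant_abc123 created successfully"}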
backend/ee/danswer/server/tenants/models.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from pydantic import BaseModel


class CreateTenantRequest(BaseModel):
    tenant_id: str
    initial_admin_email: str
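As a quick illustration of the payload this model accepts (a hypothetical snippet, not part of the commit; the field values are placeholders):

from ee.danswer.server.tenants.models import CreateTenantRequest

# tenant_id doubles as the Postgres schema name in the provisioning helpers below.
req = CreateTenantRequest(tenant_id="tenant_abc123", initial_admin_email="admin@example.com")
print(req.tenant_id, req.initial_admin_email)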
backend/ee/danswer/server/tenants/provisioning.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import os
from types import SimpleNamespace

from sqlalchemy import text
from sqlalchemy.orm import Session
from sqlalchemy.schema import CreateSchema

from alembic import command
from alembic.config import Config
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_sqlalchemy_engine
from danswer.utils.logger import setup_logger

logger = setup_logger()


def run_alembic_migrations(schema_name: str) -> None:
    logger.info(f"Starting Alembic migrations for schema: {schema_name}")

    try:
        current_dir = os.path.dirname(os.path.abspath(__file__))
        root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", ".."))
        alembic_ini_path = os.path.join(root_dir, "alembic.ini")

        # Configure Alembic
        alembic_cfg = Config(alembic_ini_path)
        alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string())
        alembic_cfg.set_main_option(
            "script_location", os.path.join(root_dir, "alembic")
        )

        # Mimic command-line options by adding 'cmd_opts' to the config
        alembic_cfg.cmd_opts = SimpleNamespace()  # type: ignore
        alembic_cfg.cmd_opts.x = [f"schema={schema_name}"]  # type: ignore

        # Run migrations programmatically
        command.upgrade(alembic_cfg, "head")

        logger.info(
            f"Alembic migrations completed successfully for schema: {schema_name}"
        )

    except Exception as e:
        logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}")
        raise


def ensure_schema_exists(tenant_id: str) -> bool:
    with Session(get_sqlalchemy_engine()) as db_session:
        with db_session.begin():
            result = db_session.execute(
                text(
                    "SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"
                ),
                {"schema_name": tenant_id},
            )
            schema_exists = result.scalar() is not None
            if not schema_exists:
                stmt = CreateSchema(tenant_id)
                db_session.execute(stmt)
                return True
            return False
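These helpers can also be exercised outside the API path; a minimal sketch, assuming a reachable Postgres instance and the repository's Alembic environment (the tenant id is a placeholder):

from ee.danswer.server.tenants.provisioning import ensure_schema_exists, run_alembic_migrations

tenant_id = "tenant_abc123"  # placeholder schema name

# Create the schema only if it does not exist yet, then bring it up to the latest revision.
newly_created = ensure_schema_exists(tenant_id)
run_alembic_migrations(tenant_id)
print("schema created" if newly_created else "schema already existed")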
@@ -18,8 +18,8 @@ from danswer.db.swap_index import check_index_swap
 from danswer.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
 from danswer.document_index.vespa.index import VespaIndex
 from danswer.indexing.models import IndexingSetting
-from danswer.main import setup_postgres
-from danswer.main import setup_vespa
+from danswer.setup import setup_postgres
+from danswer.setup import setup_vespa
 from danswer.utils.logger import setup_logger

 logger = setup_logger()
@@ -14,6 +14,7 @@ import {
   ValidQuestionResponse,
   Relevance,
   SearchDanswerDocument,
+  SourceMetadata,
 } from "@/lib/search/interfaces";
 import { searchRequestStreamed } from "@/lib/search/streamingQa";
 import { CancellationToken, cancellable } from "@/lib/search/cancellable";
@@ -40,6 +41,9 @@ import { ApiKeyModal } from "../llm/ApiKeyModal";
 import { useSearchContext } from "../context/SearchContext";
 import { useUser } from "../user/UserProvider";
 import UnconfiguredProviderText from "../chat_search/UnconfiguredProviderText";
+import { DateRangePickerValue } from "@tremor/react";
+import { Tag } from "@/lib/types";
+import { isEqual } from "lodash";

 export type searchState =
   | "input"
@@ -370,8 +374,36 @@ export const SearchSection = ({
     setSearchAnswerExpanded(false);
   };

-  const [previousSearch, setPreviousSearch] = useState<string>("");
+  interface SearchDetails {
+    query: string;
+    sources: SourceMetadata[];
+    agentic: boolean;
+    documentSets: string[];
+    timeRange: DateRangePickerValue | null;
+    tags: Tag[];
+    persona: Persona;
+  }
+
+  const [previousSearch, setPreviousSearch] = useState<null | SearchDetails>(
+    null
+  );
   const [agenticResults, setAgenticResults] = useState<boolean | null>(null);
+  const currentSearch = (overrideMessage?: string): SearchDetails => {
+    return {
+      query: overrideMessage || query,
+      sources: filterManager.selectedSources,
+      agentic: agentic!,
+      documentSets: filterManager.selectedDocumentSets,
+      timeRange: filterManager.timeRange,
+      tags: filterManager.selectedTags,
+      persona: assistants.find(
+        (assistant) => assistant.id === selectedPersona
+      ) as Persona,
+    };
+  };
+  const isSearchChanged = () => {
+    return !isEqual(currentSearch(), previousSearch);
+  };

   let lastSearchCancellationToken = useRef<CancellationToken | null>(null);
   const onSearch = async ({
@@ -394,7 +426,9 @@ export const SearchSection = ({

     setIsFetching(true);
     setSearchResponse(initialSearchResponse);
-    setPreviousSearch(overrideMessage || query);
+
+    setPreviousSearch(currentSearch(overrideMessage));
+
     const searchFnArgs = {
       query: overrideMessage || query,
       sources: filterManager.selectedSources,
@@ -761,7 +795,7 @@ export const SearchSection = ({
           />

           <FullSearchBar
-            disabled={previousSearch === query}
+            disabled={!isSearchChanged()}
             toggleAgentic={
               disabledAgentic ? undefined : toggleAgentic
             }