mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-07 11:28:09 +02:00
various multi tenant improvements (#2803)
* various multi tenant improvements * nit * ensure consistent db session operations * minor robustification
This commit is contained in:
parent
0e6c2f0b51
commit
bfe963988e
@ -12,7 +12,7 @@ from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
@ -36,7 +36,7 @@ def check_for_connector_deletion_task(tenant_id: str | None) -> None:
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
try_generate_document_cc_pair_cleanup_tasks(
|
||||
|
@ -14,7 +14,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.background.celery.celery_app import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import PostgresAdvisoryLocks
|
||||
from danswer.db.engine import get_sqlalchemy_engine # type: ignore
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
|
||||
|
||||
@shared_task(
|
||||
@ -23,7 +23,7 @@ from danswer.db.engine import get_sqlalchemy_engine # type: ignore
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
)
|
||||
def kombu_message_cleanup_task(self: Any) -> int:
|
||||
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
|
||||
"""Runs periodically to clean up the kombu_message table"""
|
||||
|
||||
# we will select messages older than this amount to clean up
|
||||
@ -35,7 +35,7 @@ def kombu_message_cleanup_task(self: Any) -> int:
|
||||
ctx["deleted"] = 0
|
||||
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
|
||||
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Exit the task if we can't take the advisory lock
|
||||
result = db_session.execute(
|
||||
text("SELECT pg_try_advisory_lock(:id)"),
|
||||
|
@ -39,7 +39,6 @@ from danswer.db.document_set import fetch_document_sets_for_document
|
||||
from danswer.db.document_set import get_document_set_by_id
|
||||
from danswer.db.document_set import mark_document_set_as_synced
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import delete_index_attempts
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import UserGroup
|
||||
@ -341,7 +340,9 @@ def monitor_document_set_taskset(
|
||||
r.delete(rds.fence_key)
|
||||
|
||||
|
||||
def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
|
||||
def monitor_connector_deletion_taskset(
|
||||
key_bytes: bytes, r: Redis, tenant_id: str | None
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id is None:
|
||||
@ -367,7 +368,7 @@ def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
@ -529,7 +530,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> None:
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||
monitor_connector_deletion_taskset(key_bytes, r)
|
||||
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
lock_beat.reacquire()
|
||||
|
@ -65,6 +65,7 @@ def _get_connector_runner(
|
||||
input_type=task,
|
||||
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
|
||||
credential=attempt.connector_credential_pair.credential,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
|
@ -118,6 +118,9 @@ class DocumentSource(str, Enum):
|
||||
NOT_APPLICABLE = "not_applicable"
|
||||
|
||||
|
||||
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
|
||||
|
||||
|
||||
class NotificationType(str, Enum):
|
||||
REINDEX = "reindex"
|
||||
|
||||
|
@ -4,6 +4,7 @@ from typing import Type
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import DocumentSourceRequiringTenantContext
|
||||
from danswer.connectors.asana.connector import AsanaConnector
|
||||
from danswer.connectors.axero.connector import AxeroConnector
|
||||
from danswer.connectors.blob.connector import BlobStorageConnector
|
||||
@ -134,8 +135,13 @@ def instantiate_connector(
|
||||
input_type: InputType,
|
||||
connector_specific_config: dict[str, Any],
|
||||
credential: Credential,
|
||||
tenant_id: str | None = None,
|
||||
) -> BaseConnector:
|
||||
connector_class = identify_connector_class(source, input_type)
|
||||
|
||||
if source in DocumentSourceRequiringTenantContext:
|
||||
connector_specific_config["tenant_id"] = tenant_id
|
||||
|
||||
connector = connector_class(**connector_specific_config)
|
||||
new_credentials = connector.load_credentials(credential.credential_json)
|
||||
|
||||
|
@ -10,13 +10,14 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.file_processing.extract_file_text import check_file_ext_is_valid
|
||||
from danswer.file_processing.extract_file_text import detect_encoding
|
||||
from danswer.file_processing.extract_file_text import extract_file_text
|
||||
@ -159,10 +160,12 @@ class LocalFileConnector(LoadConnector):
|
||||
def __init__(
|
||||
self,
|
||||
file_locations: list[Path | str],
|
||||
tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.file_locations = [Path(file_location) for file_location in file_locations]
|
||||
self.batch_size = batch_size
|
||||
self.tenant_id = tenant_id
|
||||
self.pdf_pass: str | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
@ -171,7 +174,7 @@ class LocalFileConnector(LoadConnector):
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
documents: list[Document] = []
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with get_session_with_tenant(self.tenant_id) as db_session:
|
||||
for file_path in self.file_locations:
|
||||
current_datetime = datetime.now(timezone.utc)
|
||||
files = _read_files_and_metadata(
|
||||
|
@ -435,14 +435,13 @@ def cancel_indexing_attempts_for_ccpair(
|
||||
|
||||
db_session.execute(stmt)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def cancel_indexing_attempts_past_model(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Stops all indexing attempts that are in progress or not started for
|
||||
any embedding model that not present/future"""
|
||||
|
||||
db_session.execute(
|
||||
update(IndexAttempt)
|
||||
.where(
|
||||
@ -455,8 +454,6 @@ def cancel_indexing_attempts_past_model(
|
||||
.values(status=IndexingStatus.FAILED)
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def count_unique_cc_pairs_with_successful_index_attempts(
|
||||
search_settings_id: int | None,
|
||||
|
@ -154,6 +154,7 @@ def update_cc_pair_status(
|
||||
user=user,
|
||||
get_editable=True,
|
||||
)
|
||||
|
||||
if not cc_pair:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@ -163,7 +164,6 @@ def update_cc_pair_status(
|
||||
if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
|
||||
cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
|
||||
|
||||
# Just for good measure
|
||||
cancel_indexing_attempts_past_model(db_session)
|
||||
|
||||
update_connector_credential_pair_from_id(
|
||||
@ -172,6 +172,8 @@ def update_cc_pair_status(
|
||||
status=status_update_request.status,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@router.put("/admin/cc-pair/{cc_pair_id}/name")
|
||||
def update_cc_pair_name(
|
||||
|
@ -115,6 +115,7 @@ def set_new_search_settings(
|
||||
for cc_pair in get_connector_credential_pairs(db_session):
|
||||
resync_cc_pair(cc_pair, db_session=db_session)
|
||||
|
||||
db_session.commit()
|
||||
return IdReturn(id=new_search_settings.id)
|
||||
|
||||
|
||||
|
@ -13,6 +13,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import ChatSession
|
||||
from danswer.db.models import TokenRateLimit
|
||||
@ -20,6 +21,7 @@ from danswer.db.models import User
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
|
||||
from shared_configs.configs import current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@ -39,11 +41,11 @@ def check_token_rate_limits(
|
||||
versioned_rate_limit_strategy = fetch_versioned_implementation(
|
||||
"danswer.server.query_and_chat.token_limit", "_check_token_rate_limits"
|
||||
)
|
||||
return versioned_rate_limit_strategy(user)
|
||||
return versioned_rate_limit_strategy(user, current_tenant_id.get())
|
||||
|
||||
|
||||
def _check_token_rate_limits(_: User | None) -> None:
|
||||
_user_is_rate_limited_by_global()
|
||||
def _check_token_rate_limits(_: User | None, tenant_id: str | None) -> None:
|
||||
_user_is_rate_limited_by_global(tenant_id)
|
||||
|
||||
|
||||
"""
|
||||
@ -51,8 +53,8 @@ Global rate limits
|
||||
"""
|
||||
|
||||
|
||||
def _user_is_rate_limited_by_global() -> None:
|
||||
with get_session_context_manager() as db_session:
|
||||
def _user_is_rate_limited_by_global(tenant_id: str | None) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
global_rate_limits = fetch_all_global_token_rate_limits(
|
||||
db_session=db_session, enabled_only=True, ordered=False
|
||||
)
|
||||
|
@ -12,7 +12,7 @@ from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import ChatSession
|
||||
from danswer.db.models import TokenRateLimit
|
||||
@ -28,21 +28,21 @@ from ee.danswer.db.api_key import is_api_key_email_address
|
||||
from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits
|
||||
|
||||
|
||||
def _check_token_rate_limits(user: User | None) -> None:
|
||||
def _check_token_rate_limits(user: User | None, tenant_id: str | None) -> None:
|
||||
if user is None:
|
||||
# Unauthenticated users are only rate limited by global settings
|
||||
_user_is_rate_limited_by_global()
|
||||
_user_is_rate_limited_by_global(tenant_id)
|
||||
|
||||
elif is_api_key_email_address(user.email):
|
||||
# API keys are only rate limited by global settings
|
||||
_user_is_rate_limited_by_global()
|
||||
_user_is_rate_limited_by_global(tenant_id)
|
||||
|
||||
else:
|
||||
run_functions_tuples_in_parallel(
|
||||
[
|
||||
(_user_is_rate_limited, (user.id,)),
|
||||
(_user_is_rate_limited_by_group, (user.id,)),
|
||||
(_user_is_rate_limited_by_global, ()),
|
||||
(_user_is_rate_limited, (user.id, tenant_id)),
|
||||
(_user_is_rate_limited_by_group, (user.id, tenant_id)),
|
||||
(_user_is_rate_limited_by_global, (tenant_id,)),
|
||||
]
|
||||
)
|
||||
|
||||
@ -52,8 +52,8 @@ User rate limits
|
||||
"""
|
||||
|
||||
|
||||
def _user_is_rate_limited(user_id: UUID) -> None:
|
||||
with get_session_context_manager() as db_session:
|
||||
def _user_is_rate_limited(user_id: UUID, tenant_id: str | None) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
user_rate_limits = fetch_all_user_token_rate_limits(
|
||||
db_session=db_session, enabled_only=True, ordered=False
|
||||
)
|
||||
@ -93,8 +93,8 @@ User Group rate limits
|
||||
"""
|
||||
|
||||
|
||||
def _user_is_rate_limited_by_group(user_id: UUID) -> None:
|
||||
with get_session_context_manager() as db_session:
|
||||
def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)
|
||||
|
||||
if group_rate_limits:
|
||||
|
@ -206,6 +206,8 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
|
||||
logger.notice(f"Deleting file {file_name}")
|
||||
file_store.delete_file(file_name)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Delete a connector by its ID")
|
||||
|
@ -1,3 +1,4 @@
|
||||
import { CLOUD_ENABLED } from "@/lib/constants";
|
||||
import { getAuthTypeMetadataSS, logoutSS } from "@/lib/userSS";
|
||||
import { NextRequest } from "next/server";
|
||||
|
||||
@ -6,8 +7,38 @@ export const POST = async (request: NextRequest) => {
|
||||
// Needed since env variables don't work well on the client-side
|
||||
const authTypeMetadata = await getAuthTypeMetadataSS();
|
||||
const response = await logoutSS(authTypeMetadata.authType, request.headers);
|
||||
if (!response || response.ok) {
|
||||
|
||||
if (response && !response.ok) {
|
||||
return new Response(response.body, { status: response?.status });
|
||||
}
|
||||
|
||||
// Delete cookies only if cloud is enabled (jwt auth)
|
||||
if (CLOUD_ENABLED) {
|
||||
const cookiesToDelete = ["fastapiusersauth", "tenant_details"];
|
||||
const cookieOptions = {
|
||||
path: "/",
|
||||
secure: process.env.NODE_ENV === "production",
|
||||
httpOnly: true,
|
||||
sameSite: "lax" as const,
|
||||
};
|
||||
|
||||
// Logout successful, delete cookies
|
||||
const headers = new Headers();
|
||||
|
||||
cookiesToDelete.forEach((cookieName) => {
|
||||
headers.append(
|
||||
"Set-Cookie",
|
||||
`${cookieName}=; Max-Age=0; ${Object.entries(cookieOptions)
|
||||
.map(([key, value]) => `${key}=${value}`)
|
||||
.join("; ")}`
|
||||
);
|
||||
});
|
||||
|
||||
return new Response(null, {
|
||||
status: 204,
|
||||
headers: headers,
|
||||
});
|
||||
} else {
|
||||
return new Response(null, { status: 204 });
|
||||
}
|
||||
return new Response(response.body, { status: response?.status });
|
||||
};
|
||||
|
@ -58,6 +58,7 @@ export const DISABLE_LLM_DOC_RELEVANCE =
|
||||
|
||||
export const CLOUD_ENABLED =
|
||||
process.env.NEXT_PUBLIC_CLOUD_ENABLED?.toLowerCase() === "true";
|
||||
|
||||
export const REGISTRATION_URL =
|
||||
process.env.INTERNAL_URL || "http://127.0.0.1:3001";
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user