various multi-tenant improvements (#2803)

* various multi-tenant improvements

* nit

* ensure consistent db session operations

* minor robustification
pablodanswer 2024-10-15 13:10:57 -07:00 committed by GitHub
parent 0e6c2f0b51
commit bfe963988e
15 changed files with 84 additions and 34 deletions

View File

@@ -12,7 +12,7 @@ from danswer.configs.app_configs import JOB_TIMEOUT
 from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
 from danswer.configs.constants import DanswerRedisLocks
 from danswer.db.connector_credential_pair import get_connector_credential_pairs
-from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.engine import get_session_with_tenant
 from danswer.db.enums import ConnectorCredentialPairStatus
 from danswer.db.models import ConnectorCredentialPair
 from danswer.redis.redis_pool import get_redis_client
@@ -36,7 +36,7 @@ def check_for_connector_deletion_task(tenant_id: str | None) -> None:
         if not lock_beat.acquire(blocking=False):
             return
 
-        with Session(get_sqlalchemy_engine()) as db_session:
+        with get_session_with_tenant(tenant_id) as db_session:
             cc_pairs = get_connector_credential_pairs(db_session)
             for cc_pair in cc_pairs:
                 try_generate_document_cc_pair_cleanup_tasks(
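Reviewer note: this diff leans on get_session_with_tenant but never shows its body. A minimal sketch of what such a tenant-scoped session helper typically looks like in a schema-per-tenant Postgres setup (the DSN, validation rule, and search_path approach are assumptions, not danswer's confirmed implementation):

```python
# Hypothetical sketch -- the real get_session_with_tenant in
# danswer.db.engine may differ.
from contextlib import contextmanager
from typing import Iterator

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

POSTGRES_DEFAULT_SCHEMA = "public"
engine = create_engine("postgresql://localhost/danswer")  # assumed DSN

@contextmanager
def get_session_with_tenant(tenant_id: str | None = None) -> Iterator[Session]:
    """Yield a session whose search_path targets the tenant's schema."""
    schema = tenant_id or POSTGRES_DEFAULT_SCHEMA
    # Guard against injection before interpolating the schema name.
    if not schema.replace("_", "").isalnum():
        raise ValueError(f"invalid tenant schema: {schema}")
    with Session(engine, expire_on_commit=False) as session:
        # All unqualified table names now resolve inside the tenant schema.
        session.execute(text(f'SET search_path TO "{schema}"'))
        yield session
```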

View File

@@ -14,7 +14,7 @@ from sqlalchemy.orm import Session
 from danswer.background.celery.celery_app import task_logger
 from danswer.configs.app_configs import JOB_TIMEOUT
 from danswer.configs.constants import PostgresAdvisoryLocks
-from danswer.db.engine import get_sqlalchemy_engine  # type: ignore
+from danswer.db.engine import get_session_with_tenant
 
 
 @shared_task(
@@ -23,7 +23,7 @@ from danswer.db.engine import get_sqlalchemy_engine  # type: ignore
     bind=True,
     base=AbortableTask,
 )
-def kombu_message_cleanup_task(self: Any) -> int:
+def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
     """Runs periodically to clean up the kombu_message table"""
 
     # we will select messages older than this amount to clean up
@@ -35,7 +35,7 @@ def kombu_message_cleanup_task(self: Any) -> int:
     ctx["deleted"] = 0
     ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
     ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         # Exit the task if we can't take the advisory lock
         result = db_session.execute(
             text("SELECT pg_try_advisory_lock(:id)"),

View File

@@ -39,7 +39,6 @@ from danswer.db.document_set import fetch_document_sets_for_document
 from danswer.db.document_set import get_document_set_by_id
 from danswer.db.document_set import mark_document_set_as_synced
 from danswer.db.engine import get_session_with_tenant
-from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.index_attempt import delete_index_attempts
 from danswer.db.models import DocumentSet
 from danswer.db.models import UserGroup
@@ -341,7 +340,9 @@ def monitor_document_set_taskset(
     r.delete(rds.fence_key)
 
 
-def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
+def monitor_connector_deletion_taskset(
+    key_bytes: bytes, r: Redis, tenant_id: str | None
+) -> None:
     fence_key = key_bytes.decode("utf-8")
     cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
     if cc_pair_id is None:
@@ -367,7 +368,7 @@ def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
     if count > 0:
         return
 
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
         if not cc_pair:
             task_logger.warning(
@@ -529,7 +530,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> None:
         lock_beat.reacquire()
         for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
-            monitor_connector_deletion_taskset(key_bytes, r)
+            monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
 
         with get_session_with_tenant(tenant_id) as db_session:
             lock_beat.reacquire()
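Reviewer note: monitor_vespa_sync scans Redis for deletion "fence" keys and derives the cc_pair id from each key before opening a tenant-scoped session. A sketch of how such a fence-key scheme could encode the id (the prefix and format here are assumptions; the real RedisConnectorDeletion class may differ):

```python
# Hypothetical sketch of the fence-key scheme.
class RedisConnectorDeletion:
    FENCE_PREFIX = "connectordeletion_fence_"

    @classmethod
    def get_id_from_fence_key(cls, key: str) -> int | None:
        # e.g. "connectordeletion_fence_42" -> 42; the caller decodes the
        # raw bytes from r.scan_iter() before passing the key in.
        if not key.startswith(cls.FENCE_PREFIX):
            return None
        suffix = key[len(cls.FENCE_PREFIX):]
        return int(suffix) if suffix.isdigit() else None
```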

View File

@@ -65,6 +65,7 @@ def _get_connector_runner(
             input_type=task,
             connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
             credential=attempt.connector_credential_pair.credential,
+            tenant_id=tenant_id,
         )
     except Exception as e:
         logger.exception(f"Unable to instantiate connector due to {e}")

View File

@@ -118,6 +118,9 @@ class DocumentSource(str, Enum):
     NOT_APPLICABLE = "not_applicable"
 
 
+DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
+
+
 class NotificationType(str, Enum):
     REINDEX = "reindex"

View File

@@ -4,6 +4,7 @@ from typing import Type
 from sqlalchemy.orm import Session
 
 from danswer.configs.constants import DocumentSource
+from danswer.configs.constants import DocumentSourceRequiringTenantContext
 from danswer.connectors.asana.connector import AsanaConnector
 from danswer.connectors.axero.connector import AxeroConnector
 from danswer.connectors.blob.connector import BlobStorageConnector
@@ -134,8 +135,13 @@ def instantiate_connector(
     input_type: InputType,
     connector_specific_config: dict[str, Any],
     credential: Credential,
+    tenant_id: str | None = None,
 ) -> BaseConnector:
     connector_class = identify_connector_class(source, input_type)
+
+    if source in DocumentSourceRequiringTenantContext:
+        connector_specific_config["tenant_id"] = tenant_id
+
     connector = connector_class(**connector_specific_config)
     new_credentials = connector.load_credentials(credential.credential_json)
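Reviewer note: the gating means only sources listed in DocumentSourceRequiringTenantContext (currently just FILE, whose connector needs a DB session to reach the tenant's file store) receive the tenant id as a constructor kwarg; every other connector's signature stays untouched. A small sketch of the effect (config values are hypothetical):

```python
# Sketch of the gating effect.
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import DocumentSourceRequiringTenantContext

connector_specific_config = {"file_locations": ["/tmp/docs/a.pdf"]}
source = DocumentSource.FILE
tenant_id = "tenant_abc"

# FILE is in the list, so the kwarg gets injected; a web or Slack source
# would fall through and be constructed with its config unchanged.
if source in DocumentSourceRequiringTenantContext:
    connector_specific_config["tenant_id"] = tenant_id
```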

View File

@@ -10,13 +10,14 @@ from sqlalchemy.orm import Session
 
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
+from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
 from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
 from danswer.connectors.interfaces import GenerateDocumentsOutput
 from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.models import BasicExpertInfo
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
-from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.engine import get_session_with_tenant
 from danswer.file_processing.extract_file_text import check_file_ext_is_valid
 from danswer.file_processing.extract_file_text import detect_encoding
 from danswer.file_processing.extract_file_text import extract_file_text
@@ -159,10 +160,12 @@ class LocalFileConnector(LoadConnector):
     def __init__(
         self,
         file_locations: list[Path | str],
+        tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
         batch_size: int = INDEX_BATCH_SIZE,
     ) -> None:
         self.file_locations = [Path(file_location) for file_location in file_locations]
         self.batch_size = batch_size
+        self.tenant_id = tenant_id
         self.pdf_pass: str | None = None
 
     def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
@@ -171,7 +174,7 @@ class LocalFileConnector(LoadConnector):
     def load_from_state(self) -> GenerateDocumentsOutput:
         documents: list[Document] = []
 
-        with Session(get_sqlalchemy_engine()) as db_session:
+        with get_session_with_tenant(self.tenant_id) as db_session:
             for file_path in self.file_locations:
                 current_datetime = datetime.now(timezone.utc)
                 files = _read_files_and_metadata(
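Reviewer note: with the new parameter the file connector opens its session against the caller's tenant schema, defaulting to the public schema for single-tenant deployments. Hypothetical usage of the updated constructor (the import path is assumed):

```python
# Hypothetical usage -- module path assumed.
from danswer.connectors.file.connector import LocalFileConnector

connector = LocalFileConnector(
    file_locations=["uploaded/handbook.pdf"],
    tenant_id="tenant_abc",  # omit to fall back to POSTGRES_DEFAULT_SCHEMA
)
connector.load_credentials({})
for batch in connector.load_from_state():
    print(f"loaded a batch of {len(batch)} documents")
```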

View File

@@ -435,14 +435,13 @@ def cancel_indexing_attempts_for_ccpair(
     db_session.execute(stmt)
-    db_session.commit()
 
 
 def cancel_indexing_attempts_past_model(
     db_session: Session,
 ) -> None:
     """Stops all indexing attempts that are in progress or not started for
     any embedding model that not present/future"""
 
     db_session.execute(
         update(IndexAttempt)
         .where(
@@ -455,8 +454,6 @@ def cancel_indexing_attempts_past_model(
         .values(status=IndexingStatus.FAILED)
     )
-    db_session.commit()
-
 
 def count_unique_cc_pairs_with_successful_index_attempts(
     search_settings_id: int | None,

View File

@@ -154,6 +154,7 @@ def update_cc_pair_status(
         user=user,
         get_editable=True,
     )
+
     if not cc_pair:
         raise HTTPException(
             status_code=400,
@@ -163,7 +164,6 @@ def update_cc_pair_status(
     if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
         cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
-
         # Just for good measure
         cancel_indexing_attempts_past_model(db_session)
 
     update_connector_credential_pair_from_id(
@@ -172,6 +172,8 @@ def update_cc_pair_status(
         status=status_update_request.status,
     )
 
+    db_session.commit()
+
 
 @router.put("/admin/cc-pair/{cc_pair_id}/name")
 def update_cc_pair_name(
@router.put("/admin/cc-pair/{cc_pair_id}/name")
def update_cc_pair_name(

View File

@@ -115,6 +115,7 @@ def set_new_search_settings(
     for cc_pair in get_connector_credential_pairs(db_session):
         resync_cc_pair(cc_pair, db_session=db_session)
 
+    db_session.commit()
 
     return IdReturn(id=new_search_settings.id)

View File

@@ -13,6 +13,7 @@ from sqlalchemy.orm import Session
 from danswer.auth.users import current_user
 from danswer.db.engine import get_session_context_manager
+from danswer.db.engine import get_session_with_tenant
 from danswer.db.models import ChatMessage
 from danswer.db.models import ChatSession
 from danswer.db.models import TokenRateLimit
@@ -20,6 +21,7 @@ from danswer.db.models import User
 from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import fetch_versioned_implementation
 from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
+from shared_configs.configs import current_tenant_id
 
 logger = setup_logger()
@@ -39,11 +41,11 @@ def check_token_rate_limits(
     versioned_rate_limit_strategy = fetch_versioned_implementation(
         "danswer.server.query_and_chat.token_limit", "_check_token_rate_limits"
     )
-    return versioned_rate_limit_strategy(user)
+    return versioned_rate_limit_strategy(user, current_tenant_id.get())
 
 
-def _check_token_rate_limits(_: User | None) -> None:
-    _user_is_rate_limited_by_global()
+def _check_token_rate_limits(_: User | None, tenant_id: str | None) -> None:
+    _user_is_rate_limited_by_global(tenant_id)
 
 
 """
@@ -51,8 +53,8 @@ Global rate limits
 """
 
-def _user_is_rate_limited_by_global() -> None:
-    with get_session_context_manager() as db_session:
+def _user_is_rate_limited_by_global(tenant_id: str | None) -> None:
+    with get_session_with_tenant(tenant_id) as db_session:
         global_rate_limits = fetch_all_global_token_rate_limits(
             db_session=db_session, enabled_only=True, ordered=False
         )
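Reviewer note: current_tenant_id is read here via .get(), which points at a request-scoped context variable in shared_configs.configs. A minimal sketch of that pattern (the default value and the middleware comment are assumptions):

```python
# Hypothetical sketch of shared_configs.configs.current_tenant_id.
from contextvars import ContextVar

POSTGRES_DEFAULT_SCHEMA = "public"

# Set once per request (e.g. by middleware that resolves the tenant from
# a JWT), then read anywhere downstream without threading it through
# every function signature.
current_tenant_id: ContextVar[str] = ContextVar(
    "current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
)

token = current_tenant_id.set("tenant_abc")
try:
    assert current_tenant_id.get() == "tenant_abc"
finally:
    current_tenant_id.reset(token)  # restore the previous value
```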

View File

@@ -12,7 +12,7 @@ from sqlalchemy import func
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
-from danswer.db.engine import get_session_context_manager
+from danswer.db.engine import get_session_with_tenant
 from danswer.db.models import ChatMessage
 from danswer.db.models import ChatSession
 from danswer.db.models import TokenRateLimit
@@ -28,21 +28,21 @@ from ee.danswer.db.api_key import is_api_key_email_address
 from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits
 
 
-def _check_token_rate_limits(user: User | None) -> None:
+def _check_token_rate_limits(user: User | None, tenant_id: str | None) -> None:
     if user is None:
         # Unauthenticated users are only rate limited by global settings
-        _user_is_rate_limited_by_global()
+        _user_is_rate_limited_by_global(tenant_id)
     elif is_api_key_email_address(user.email):
         # API keys are only rate limited by global settings
-        _user_is_rate_limited_by_global()
+        _user_is_rate_limited_by_global(tenant_id)
     else:
         run_functions_tuples_in_parallel(
             [
-                (_user_is_rate_limited, (user.id,)),
-                (_user_is_rate_limited_by_group, (user.id,)),
-                (_user_is_rate_limited_by_global, ()),
+                (_user_is_rate_limited, (user.id, tenant_id)),
+                (_user_is_rate_limited_by_group, (user.id, tenant_id)),
+                (_user_is_rate_limited_by_global, (tenant_id,)),
             ]
         )
@@ -52,8 +52,8 @@ User rate limits
 """
 
-def _user_is_rate_limited(user_id: UUID) -> None:
-    with get_session_context_manager() as db_session:
+def _user_is_rate_limited(user_id: UUID, tenant_id: str | None) -> None:
+    with get_session_with_tenant(tenant_id) as db_session:
         user_rate_limits = fetch_all_user_token_rate_limits(
             db_session=db_session, enabled_only=True, ordered=False
         )
@@ -93,8 +93,8 @@ User Group rate limits
 """
 
-def _user_is_rate_limited_by_group(user_id: UUID) -> None:
-    with get_session_context_manager() as db_session:
+def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
+    with get_session_with_tenant(tenant_id) as db_session:
         group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)
 
         if group_rate_limits:
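Reviewer note: the per-user, per-group, and global checks run concurrently, each tuple now carrying the tenant id explicitly (ContextVars don't automatically follow work into other threads unless the context is copied). A sketch of what a (function, args) fan-out helper like run_functions_tuples_in_parallel might look like (an assumption; danswer's real helper may differ in signature and error handling):

```python
# Hypothetical sketch of a (function, args) parallel runner.
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable

def run_functions_tuples_in_parallel(
    calls: list[tuple[Callable[..., Any], tuple[Any, ...]]],
) -> list[Any]:
    with ThreadPoolExecutor(max_workers=len(calls)) as executor:
        futures = [executor.submit(fn, *args) for fn, args in calls]
        # result() re-raises, so a rate-limit exception raised by any
        # check propagates to the caller just as a sequential call would.
        return [f.result() for f in futures]
```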

View File

@@ -206,6 +206,8 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
             logger.notice(f"Deleting file {file_name}")
             file_store.delete_file(file_name)
 
+    db_session.commit()
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Delete a connector by its ID")

View File

@@ -1,3 +1,4 @@
+import { CLOUD_ENABLED } from "@/lib/constants";
 import { getAuthTypeMetadataSS, logoutSS } from "@/lib/userSS";
 import { NextRequest } from "next/server";
 
@@ -6,8 +7,38 @@ export const POST = async (request: NextRequest) => {
   // Needed since env variables don't work well on the client-side
   const authTypeMetadata = await getAuthTypeMetadataSS();
   const response = await logoutSS(authTypeMetadata.authType, request.headers);
-  if (!response || response.ok) {
-    return new Response(null, { status: 204 });
+
+  if (response && !response.ok) {
+    return new Response(response.body, { status: response?.status });
   }
+
+  // Delete cookies only if cloud is enabled (jwt auth)
+  if (CLOUD_ENABLED) {
+    const cookiesToDelete = ["fastapiusersauth", "tenant_details"];
+    const cookieOptions = {
+      path: "/",
+      secure: process.env.NODE_ENV === "production",
+      httpOnly: true,
+      sameSite: "lax" as const,
+    };
+
+    // Logout successful, delete cookies
+    const headers = new Headers();
+
+    cookiesToDelete.forEach((cookieName) => {
+      headers.append(
+        "Set-Cookie",
+        `${cookieName}=; Max-Age=0; ${Object.entries(cookieOptions)
+          .map(([key, value]) => `${key}=${value}`)
+          .join("; ")}`
+      );
+    });
+
+    return new Response(null, {
+      status: 204,
+      headers: headers,
+    });
+  } else {
+    return new Response(null, { status: 204 });
+  }
-  return new Response(response.body, { status: response?.status });
 };
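Reviewer note: the route expires cookies by appending Set-Cookie headers with Max-Age=0, which tells the browser to drop them immediately. For comparison only, the same mechanism on the FastAPI backend would be delete_cookie; this sketch is not part of this commit and the route path is hypothetical:

```python
# Sketch of the equivalent cookie expiry on the FastAPI side -- shown
# only to illustrate the Max-Age=0 mechanism, not part of this commit.
from fastapi import FastAPI, Response

app = FastAPI()

@app.post("/logout")
def logout(response: Response) -> dict[str, str]:
    for cookie_name in ("fastapiusersauth", "tenant_details"):
        # delete_cookie emits `Set-Cookie: <name>=; Max-Age=0; ...`,
        # instructing the browser to discard the cookie immediately.
        response.delete_cookie(
            cookie_name, path="/", httponly=True, samesite="lax"
        )
    return {"status": "logged_out"}
```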

View File

@@ -58,6 +58,7 @@ export const DISABLE_LLM_DOC_RELEVANCE =
 
 export const CLOUD_ENABLED =
   process.env.NEXT_PUBLIC_CLOUD_ENABLED?.toLowerCase() === "true";
+
 export const REGISTRATION_URL =
   process.env.INTERNAL_URL || "http://127.0.0.1:3001";