various multi tenant improvements (#2803)

* various multi tenant improvements

* nit

* ensure consistent db session operations

* minor robustification
Authored by pablodanswer on 2024-10-15 13:10:57 -07:00; committed by GitHub
parent 0e6c2f0b51
commit bfe963988e
15 changed files with 84 additions and 34 deletions

View File

@@ -12,7 +12,7 @@ from danswer.configs.app_configs import JOB_TIMEOUT
 from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
 from danswer.configs.constants import DanswerRedisLocks
 from danswer.db.connector_credential_pair import get_connector_credential_pairs
-from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.engine import get_session_with_tenant
 from danswer.db.enums import ConnectorCredentialPairStatus
 from danswer.db.models import ConnectorCredentialPair
 from danswer.redis.redis_pool import get_redis_client
@@ -36,7 +36,7 @@ def check_for_connector_deletion_task(tenant_id: str | None) -> None:
     if not lock_beat.acquire(blocking=False):
         return

-    with Session(get_sqlalchemy_engine()) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         cc_pairs = get_connector_credential_pairs(db_session)
         for cc_pair in cc_pairs:
             try_generate_document_cc_pair_cleanup_tasks(
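
Note: this hunk swaps a raw engine session for the tenant-aware context manager. The implementation of get_session_with_tenant is not part of this diff; below is a minimal sketch of its likely shape, assuming tenants map to Postgres schemas and the session pins search_path (the default schema name is an assumption):

    # Sketch only -- the real get_session_with_tenant lives in danswer.db.engine.
    from collections.abc import Iterator
    from contextlib import contextmanager

    from sqlalchemy import text
    from sqlalchemy.orm import Session

    from danswer.db.engine import get_sqlalchemy_engine  # existing engine factory

    DEFAULT_SCHEMA = "public"  # assumed fallback when tenant_id is None

    @contextmanager
    def get_session_with_tenant(tenant_id: str | None = None) -> Iterator[Session]:
        """Yield a session whose search_path is pinned to the tenant's schema."""
        schema = tenant_id or DEFAULT_SCHEMA
        with Session(get_sqlalchemy_engine()) as session:
            # Scope unqualified table names to the tenant's schema.
            session.execute(text(f'SET search_path TO "{schema}"'))
            yield session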

View File

@@ -14,7 +14,7 @@ from sqlalchemy.orm import Session
 from danswer.background.celery.celery_app import task_logger
 from danswer.configs.app_configs import JOB_TIMEOUT
 from danswer.configs.constants import PostgresAdvisoryLocks
-from danswer.db.engine import get_sqlalchemy_engine  # type: ignore
+from danswer.db.engine import get_session_with_tenant

 @shared_task(
@@ -23,7 +23,7 @@ from danswer.db.engine import get_sqlalchemy_engine  # type: ignore
     bind=True,
     base=AbortableTask,
 )
-def kombu_message_cleanup_task(self: Any) -> int:
+def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
     """Runs periodically to clean up the kombu_message table"""

     # we will select messages older than this amount to clean up
@@ -35,7 +35,7 @@ def kombu_message_cleanup_task(self: Any) -> int:
     ctx["deleted"] = 0
     ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
     ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         # Exit the task if we can't take the advisory lock
         result = db_session.execute(
             text("SELECT pg_try_advisory_lock(:id)"),

View File

@@ -39,7 +39,6 @@ from danswer.db.document_set import fetch_document_sets_for_document
 from danswer.db.document_set import get_document_set_by_id
 from danswer.db.document_set import mark_document_set_as_synced
 from danswer.db.engine import get_session_with_tenant
-from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.index_attempt import delete_index_attempts
 from danswer.db.models import DocumentSet
 from danswer.db.models import UserGroup
@@ -341,7 +340,9 @@ def monitor_document_set_taskset(
     r.delete(rds.fence_key)

-def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
+def monitor_connector_deletion_taskset(
+    key_bytes: bytes, r: Redis, tenant_id: str | None
+) -> None:
     fence_key = key_bytes.decode("utf-8")
     cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
     if cc_pair_id is None:
@@ -367,7 +368,7 @@ def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
     if count > 0:
         return

-    with Session(get_sqlalchemy_engine()) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
         if not cc_pair:
             task_logger.warning(
@@ -529,7 +530,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> None:
         lock_beat.reacquire()
         for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
-            monitor_connector_deletion_taskset(key_bytes, r)
+            monitor_connector_deletion_taskset(key_bytes, r, tenant_id)

         with get_session_with_tenant(tenant_id) as db_session:
             lock_beat.reacquire()
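
Note: the monitor functions above discover work via Redis "fence" keys that encode the cc_pair id. A sketch of the id-extraction helper, assuming keys of the form <FENCE_PREFIX><id> (the real prefix and parsing live on RedisConnectorDeletion):

    FENCE_PREFIX = "connectordeletion_fence_"  # assumed key layout

    def get_id_from_fence_key(key: str) -> int | None:
        """Pull the cc_pair id out of a fence key, or None if malformed."""
        if not key.startswith(FENCE_PREFIX):
            return None
        suffix = key[len(FENCE_PREFIX):]
        return int(suffix) if suffix.isdigit() else None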

View File

@@ -65,6 +65,7 @@ def _get_connector_runner(
             input_type=task,
             connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
             credential=attempt.connector_credential_pair.credential,
+            tenant_id=tenant_id,
         )
     except Exception as e:
         logger.exception(f"Unable to instantiate connector due to {e}")

View File

@@ -118,6 +118,9 @@ class DocumentSource(str, Enum):
     NOT_APPLICABLE = "not_applicable"

+DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
+
+
 class NotificationType(str, Enum):
     REINDEX = "reindex"

View File

@@ -4,6 +4,7 @@ from typing import Type
 from sqlalchemy.orm import Session

 from danswer.configs.constants import DocumentSource
+from danswer.configs.constants import DocumentSourceRequiringTenantContext
 from danswer.connectors.asana.connector import AsanaConnector
 from danswer.connectors.axero.connector import AxeroConnector
 from danswer.connectors.blob.connector import BlobStorageConnector
@@ -134,8 +135,13 @@ def instantiate_connector(
     input_type: InputType,
     connector_specific_config: dict[str, Any],
     credential: Credential,
+    tenant_id: str | None = None,
 ) -> BaseConnector:
     connector_class = identify_connector_class(source, input_type)
+
+    if source in DocumentSourceRequiringTenantContext:
+        connector_specific_config["tenant_id"] = tenant_id
+
     connector = connector_class(**connector_specific_config)
     new_credentials = connector.load_credentials(credential.credential_json)
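
Note: a hedged usage sketch of the new signature. The values below are illustrative; only sources listed in DocumentSourceRequiringTenantContext (currently just DocumentSource.FILE) get the tenant injected into their connector config:

    # Illustrative call; the credential object and config values are made up.
    connector = instantiate_connector(
        source=DocumentSource.FILE,  # a tenant-requiring source
        input_type=InputType.LOAD_STATE,
        connector_specific_config={"file_locations": ["/tmp/example.txt"]},
        credential=credential,
        tenant_id="tenant_abc",  # hypothetical tenant id
    )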

View File

@@ -10,13 +10,14 @@ from sqlalchemy.orm import Session
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
+from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
 from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
 from danswer.connectors.interfaces import GenerateDocumentsOutput
 from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.models import BasicExpertInfo
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
-from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.engine import get_session_with_tenant
 from danswer.file_processing.extract_file_text import check_file_ext_is_valid
 from danswer.file_processing.extract_file_text import detect_encoding
 from danswer.file_processing.extract_file_text import extract_file_text
@@ -159,10 +160,12 @@ class LocalFileConnector(LoadConnector):
     def __init__(
         self,
         file_locations: list[Path | str],
+        tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
         batch_size: int = INDEX_BATCH_SIZE,
     ) -> None:
         self.file_locations = [Path(file_location) for file_location in file_locations]
         self.batch_size = batch_size
+        self.tenant_id = tenant_id
         self.pdf_pass: str | None = None

     def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
@@ -171,7 +174,7 @@ class LocalFileConnector(LoadConnector):
     def load_from_state(self) -> GenerateDocumentsOutput:
         documents: list[Document] = []
-        with Session(get_sqlalchemy_engine()) as db_session:
+        with get_session_with_tenant(self.tenant_id) as db_session:
             for file_path in self.file_locations:
                 current_datetime = datetime.now(timezone.utc)
                 files = _read_files_and_metadata(
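
Note: because tenant_id defaults to POSTGRES_DEFAULT_SCHEMA, existing single-tenant callers keep working unchanged, while multi-tenant callers pass the tenant explicitly. A sketch (paths and tenant id are made up):

    # Single-tenant: falls back to the default schema.
    connector = LocalFileConnector(file_locations=["docs/handbook.pdf"])

    # Multi-tenant: pin the DB session to one tenant's schema.
    connector = LocalFileConnector(
        file_locations=["docs/handbook.pdf"],
        tenant_id="tenant_abc",  # hypothetical tenant id
    )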

View File

@@ -435,14 +435,13 @@ def cancel_indexing_attempts_for_ccpair(
     db_session.execute(stmt)
-    db_session.commit()

 def cancel_indexing_attempts_past_model(
     db_session: Session,
 ) -> None:
     """Stops all indexing attempts that are in progress or not started for
     any embedding model that not present/future"""
     db_session.execute(
         update(IndexAttempt)
         .where(
@@ -455,8 +454,6 @@ def cancel_indexing_attempts_past_model(
         .values(status=IndexingStatus.FAILED)
     )
-    db_session.commit()

 def count_unique_cc_pairs_with_successful_index_attempts(
     search_settings_id: int | None,
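
Note: dropping these commits shifts transaction ownership to the caller (the "ensure consistent db session operations" item in the commit message): the helpers only stage changes, and the caller flushes everything atomically. A sketch of the resulting calling pattern:

    # Both cancellations now land in one transaction, committed by the caller.
    cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
    cancel_indexing_attempts_past_model(db_session)
    db_session.commit()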

View File

@@ -154,6 +154,7 @@ def update_cc_pair_status(
         user=user,
         get_editable=True,
     )
+
     if not cc_pair:
         raise HTTPException(
             status_code=400,
@@ -163,7 +164,7 @@ def update_cc_pair_status(
     if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
         cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
-        # Just for good measure
         cancel_indexing_attempts_past_model(db_session)

     update_connector_credential_pair_from_id(
@@ -172,6 +172,8 @@ def update_cc_pair_status(
         status=status_update_request.status,
     )
+    db_session.commit()
+

 @router.put("/admin/cc-pair/{cc_pair_id}/name")
 def update_cc_pair_name(

View File

@@ -115,6 +115,7 @@ def set_new_search_settings(
     for cc_pair in get_connector_credential_pairs(db_session):
         resync_cc_pair(cc_pair, db_session=db_session)
+    db_session.commit()

     return IdReturn(id=new_search_settings.id)
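
Note: the same session discipline from the other direction: resync_cc_pair stages updates per cc pair, and one explicit commit flushes the whole batch instead of relying on commits buried inside helpers:

    for cc_pair in get_connector_credential_pairs(db_session):
        resync_cc_pair(cc_pair, db_session=db_session)  # stages changes only
    db_session.commit()  # one commit for the entire resync batch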

View File

@@ -13,6 +13,7 @@ from sqlalchemy.orm import Session
 from danswer.auth.users import current_user
 from danswer.db.engine import get_session_context_manager
+from danswer.db.engine import get_session_with_tenant
 from danswer.db.models import ChatMessage
 from danswer.db.models import ChatSession
 from danswer.db.models import TokenRateLimit
@@ -20,6 +21,7 @@ from danswer.db.models import User
 from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import fetch_versioned_implementation
 from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
+from shared_configs.configs import current_tenant_id

 logger = setup_logger()
@@ -39,11 +41,11 @@ def check_token_rate_limits(
     versioned_rate_limit_strategy = fetch_versioned_implementation(
         "danswer.server.query_and_chat.token_limit", "_check_token_rate_limits"
     )
-    return versioned_rate_limit_strategy(user)
+    return versioned_rate_limit_strategy(user, current_tenant_id.get())

-def _check_token_rate_limits(_: User | None) -> None:
-    _user_is_rate_limited_by_global()
+def _check_token_rate_limits(_: User | None, tenant_id: str | None) -> None:
+    _user_is_rate_limited_by_global(tenant_id)

 """
@@ -51,8 +53,8 @@ Global rate limits
 """

-def _user_is_rate_limited_by_global() -> None:
-    with get_session_context_manager() as db_session:
+def _user_is_rate_limited_by_global(tenant_id: str | None) -> None:
+    with get_session_with_tenant(tenant_id) as db_session:
         global_rate_limits = fetch_all_global_token_rate_limits(
             db_session=db_session, enabled_only=True, ordered=False
         )
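
Note: current_tenant_id is a contextvar read at request time, so the tenant follows the request without threading it through every call site. A sketch of the assumed declaration and middleware usage (the actual shared_configs code is not in this diff):

    import contextvars

    # Assumed shape of shared_configs.configs.current_tenant_id:
    current_tenant_id = contextvars.ContextVar("current_tenant_id", default="public")

    # A request middleware (sketch) would set it before handlers run:
    token = current_tenant_id.set("tenant_abc")  # hypothetical tenant id
    try:
        ...  # handle the request; downstream code calls current_tenant_id.get()
    finally:
        current_tenant_id.reset(token)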

View File

@@ -12,7 +12,7 @@ from sqlalchemy import func
 from sqlalchemy import select
 from sqlalchemy.orm import Session

-from danswer.db.engine import get_session_context_manager
+from danswer.db.engine import get_session_with_tenant
 from danswer.db.models import ChatMessage
 from danswer.db.models import ChatSession
 from danswer.db.models import TokenRateLimit
@@ -28,21 +28,21 @@ from ee.danswer.db.api_key import is_api_key_email_address
 from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits

-def _check_token_rate_limits(user: User | None) -> None:
+def _check_token_rate_limits(user: User | None, tenant_id: str | None) -> None:
     if user is None:
         # Unauthenticated users are only rate limited by global settings
-        _user_is_rate_limited_by_global()
+        _user_is_rate_limited_by_global(tenant_id)
     elif is_api_key_email_address(user.email):
         # API keys are only rate limited by global settings
-        _user_is_rate_limited_by_global()
+        _user_is_rate_limited_by_global(tenant_id)
     else:
         run_functions_tuples_in_parallel(
             [
-                (_user_is_rate_limited, (user.id,)),
-                (_user_is_rate_limited_by_group, (user.id,)),
-                (_user_is_rate_limited_by_global, ()),
+                (_user_is_rate_limited, (user.id, tenant_id)),
+                (_user_is_rate_limited_by_group, (user.id, tenant_id)),
+                (_user_is_rate_limited_by_global, (tenant_id,)),
             ]
         )
@@ -52,8 +52,8 @@ User rate limits
 """

-def _user_is_rate_limited(user_id: UUID) -> None:
-    with get_session_context_manager() as db_session:
+def _user_is_rate_limited(user_id: UUID, tenant_id: str | None) -> None:
+    with get_session_with_tenant(tenant_id) as db_session:
         user_rate_limits = fetch_all_user_token_rate_limits(
             db_session=db_session, enabled_only=True, ordered=False
         )
@@ -93,8 +93,8 @@ User Group rate limits
 """

-def _user_is_rate_limited_by_group(user_id: UUID) -> None:
-    with get_session_context_manager() as db_session:
+def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
+    with get_session_with_tenant(tenant_id) as db_session:
         group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)

         if group_rate_limits:
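
Note: run_functions_tuples_in_parallel takes (function, args-tuple) pairs, which is why each rate-limit check now carries its own tenant_id. A minimal sketch of the utility's semantics, assuming a thread-pool implementation (the real danswer helper may differ):

    from collections.abc import Callable
    from concurrent.futures import ThreadPoolExecutor
    from typing import Any

    def run_functions_tuples_in_parallel(
        functions_with_args: list[tuple[Callable, tuple]],
    ) -> list[Any]:
        """Sketch: run each (fn, args) pair concurrently; re-raise the first error."""
        with ThreadPoolExecutor(max_workers=max(len(functions_with_args), 1)) as pool:
            futures = [pool.submit(fn, *args) for fn, args in functions_with_args]
            return [future.result() for future in futures]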

View File

@@ -206,6 +206,8 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
             logger.notice(f"Deleting file {file_name}")
             file_store.delete_file(file_name)

+    db_session.commit()
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Delete a connector by its ID")

View File

@@ -1,3 +1,4 @@
+import { CLOUD_ENABLED } from "@/lib/constants";
 import { getAuthTypeMetadataSS, logoutSS } from "@/lib/userSS";
 import { NextRequest } from "next/server";
@@ -6,8 +7,38 @@ export const POST = async (request: NextRequest) => {
   // Needed since env variables don't work well on the client-side
   const authTypeMetadata = await getAuthTypeMetadataSS();
   const response = await logoutSS(authTypeMetadata.authType, request.headers);
-  if (!response || response.ok) {
+
+  if (response && !response.ok) {
+    return new Response(response.body, { status: response?.status });
+  }
+
+  // Delete cookies only if cloud is enabled (jwt auth)
+  if (CLOUD_ENABLED) {
+    const cookiesToDelete = ["fastapiusersauth", "tenant_details"];
+    const cookieOptions = {
+      path: "/",
+      secure: process.env.NODE_ENV === "production",
+      httpOnly: true,
+      sameSite: "lax" as const,
+    };
+
+    // Logout successful, delete cookies
+    const headers = new Headers();
+
+    cookiesToDelete.forEach((cookieName) => {
+      headers.append(
+        "Set-Cookie",
+        `${cookieName}=; Max-Age=0; ${Object.entries(cookieOptions)
+          .map(([key, value]) => `${key}=${value}`)
+          .join("; ")}`
+      );
+    });
+
+    return new Response(null, {
+      status: 204,
+      headers: headers,
+    });
+  } else {
     return new Response(null, { status: 204 });
   }
-  return new Response(response.body, { status: response?.status });
 };

View File

@@ -58,6 +58,7 @@ export const DISABLE_LLM_DOC_RELEVANCE =
 export const CLOUD_ENABLED =
   process.env.NEXT_PUBLIC_CLOUD_ENABLED?.toLowerCase() === "true";
+
 export const REGISTRATION_URL =
   process.env.INTERNAL_URL || "http://127.0.0.1:3001";