diff --git a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
index 6a4c4da82..b13daff61 100644
--- a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
+++ b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
@@ -12,7 +12,7 @@ from danswer.configs.app_configs import JOB_TIMEOUT
 from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
 from danswer.configs.constants import DanswerRedisLocks
 from danswer.db.connector_credential_pair import get_connector_credential_pairs
-from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.engine import get_session_with_tenant
 from danswer.db.enums import ConnectorCredentialPairStatus
 from danswer.db.models import ConnectorCredentialPair
 from danswer.redis.redis_pool import get_redis_client
@@ -36,7 +36,7 @@ def check_for_connector_deletion_task(tenant_id: str | None) -> None:
     if not lock_beat.acquire(blocking=False):
         return
 
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         cc_pairs = get_connector_credential_pairs(db_session)
         for cc_pair in cc_pairs:
             try_generate_document_cc_pair_cleanup_tasks(
diff --git a/backend/danswer/background/celery/tasks/periodic/tasks.py b/backend/danswer/background/celery/tasks/periodic/tasks.py
index 99b1cab7e..d8da5ba9c 100644
--- a/backend/danswer/background/celery/tasks/periodic/tasks.py
+++ b/backend/danswer/background/celery/tasks/periodic/tasks.py
@@ -14,7 +14,7 @@ from sqlalchemy.orm import Session
 from danswer.background.celery.celery_app import task_logger
 from danswer.configs.app_configs import JOB_TIMEOUT
 from danswer.configs.constants import PostgresAdvisoryLocks
-from danswer.db.engine import get_sqlalchemy_engine  # type: ignore
+from danswer.db.engine import get_session_with_tenant
 
 
 @shared_task(
@@ -23,7 +23,7 @@ from danswer.db.engine import get_sqlalchemy_engine  # type: ignore
     bind=True,
     base=AbortableTask,
 )
-def kombu_message_cleanup_task(self: Any) -> int:
+def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
     """Runs periodically to clean up the kombu_message table"""
 
     # we will select messages older than this amount to clean up
@@ -35,7 +35,7 @@ def kombu_message_cleanup_task(self: Any) -> int:
     ctx["deleted"] = 0
     ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
     ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         # Exit the task if we can't take the advisory lock
         result = db_session.execute(
             text("SELECT pg_try_advisory_lock(:id)"),
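Reviewer note: both tasks above swap `Session(get_sqlalchemy_engine())` for `get_session_with_tenant(tenant_id)`. For readers unfamiliar with the helper, here is a minimal sketch of what a tenant-scoped session can look like — the real implementation lives in `danswer/db/engine.py`, and the schema-per-tenant `search_path` handling shown here is an assumption, not the actual code:

```python
# Hypothetical sketch of a tenant-scoped session helper, assuming one
# Postgres schema per tenant. The real get_session_with_tenant in
# danswer/db/engine.py may differ in details.
from collections.abc import Generator
from contextlib import contextmanager

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

POSTGRES_DEFAULT_SCHEMA = "public"  # fallback for single-tenant deployments
_engine = create_engine("postgresql://localhost/danswer")  # placeholder DSN


@contextmanager
def get_session_with_tenant(tenant_id: str | None) -> Generator[Session, None, None]:
    schema = tenant_id or POSTGRES_DEFAULT_SCHEMA
    with Session(_engine, expire_on_commit=False) as session:
        # Scope every query issued through this session to the tenant's schema.
        session.execute(text(f'SET search_path TO "{schema}"'))
        yield session
```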
diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py
index e6a017b7a..c43de3a85 100644
--- a/backend/danswer/background/celery/tasks/vespa/tasks.py
+++ b/backend/danswer/background/celery/tasks/vespa/tasks.py
@@ -39,7 +39,6 @@ from danswer.db.document_set import fetch_document_sets_for_document
 from danswer.db.document_set import get_document_set_by_id
 from danswer.db.document_set import mark_document_set_as_synced
 from danswer.db.engine import get_session_with_tenant
-from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.index_attempt import delete_index_attempts
 from danswer.db.models import DocumentSet
 from danswer.db.models import UserGroup
@@ -341,7 +340,9 @@ def monitor_document_set_taskset(
     r.delete(rds.fence_key)
 
 
-def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
+def monitor_connector_deletion_taskset(
+    key_bytes: bytes, r: Redis, tenant_id: str | None
+) -> None:
     fence_key = key_bytes.decode("utf-8")
     cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
     if cc_pair_id is None:
@@ -367,7 +368,7 @@ def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
     if count > 0:
         return
 
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
         if not cc_pair:
             task_logger.warning(
@@ -529,7 +530,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> None:
         lock_beat.reacquire()
 
         for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
-            monitor_connector_deletion_taskset(key_bytes, r)
+            monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
 
         with get_session_with_tenant(tenant_id) as db_session:
             lock_beat.reacquire()
diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py
index 69c8a7393..c48d07ffd 100644
--- a/backend/danswer/background/indexing/run_indexing.py
+++ b/backend/danswer/background/indexing/run_indexing.py
@@ -65,6 +65,7 @@ def _get_connector_runner(
             input_type=task,
             connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
             credential=attempt.connector_credential_pair.credential,
+            tenant_id=tenant_id,
         )
     except Exception as e:
         logger.exception(f"Unable to instantiate connector due to {e}")
diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py
index e4aeb88c2..6b167246a 100644
--- a/backend/danswer/configs/constants.py
+++ b/backend/danswer/configs/constants.py
@@ -118,6 +118,9 @@ class DocumentSource(str, Enum):
     NOT_APPLICABLE = "not_applicable"
 
 
+DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
+
+
 class NotificationType(str, Enum):
     REINDEX = "reindex"
diff --git a/backend/danswer/connectors/factory.py b/backend/danswer/connectors/factory.py
index 75e0d9bb2..52fb0194a 100644
--- a/backend/danswer/connectors/factory.py
+++ b/backend/danswer/connectors/factory.py
@@ -4,6 +4,7 @@ from typing import Type
 from sqlalchemy.orm import Session
 
 from danswer.configs.constants import DocumentSource
+from danswer.configs.constants import DocumentSourceRequiringTenantContext
 from danswer.connectors.asana.connector import AsanaConnector
 from danswer.connectors.axero.connector import AxeroConnector
 from danswer.connectors.blob.connector import BlobStorageConnector
@@ -134,8 +135,13 @@ def instantiate_connector(
     input_type: InputType,
     connector_specific_config: dict[str, Any],
     credential: Credential,
+    tenant_id: str | None = None,
 ) -> BaseConnector:
     connector_class = identify_connector_class(source, input_type)
+
+    if source in DocumentSourceRequiringTenantContext:
+        connector_specific_config["tenant_id"] = tenant_id
+
     connector = connector_class(**connector_specific_config)
 
     new_credentials = connector.load_credentials(credential.credential_json)
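Reviewer note: the factory change above only injects `tenant_id` for sources listed in `DocumentSourceRequiringTenantContext` (currently just `DocumentSource.FILE`). A small sketch of the resulting data flow, with placeholder values:

```python
# Illustrative only -- the file path and tenant id are placeholders. This
# mirrors the guard instantiate_connector now applies before building the
# connector from its keyword arguments.
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import DocumentSourceRequiringTenantContext

connector_specific_config = {"file_locations": ["/tmp/example.pdf"]}
tenant_id = "tenant_abc"

if DocumentSource.FILE in DocumentSourceRequiringTenantContext:
    connector_specific_config["tenant_id"] = tenant_id

# connector_specific_config now carries tenant_id, so
# LocalFileConnector(**connector_specific_config) receives it in __init__
# and can open tenant-scoped sessions in load_from_state().
```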
diff --git a/backend/danswer/connectors/file/connector.py b/backend/danswer/connectors/file/connector.py
index 8ef98716c..106fed8b2 100644
--- a/backend/danswer/connectors/file/connector.py
+++ b/backend/danswer/connectors/file/connector.py
@@ -10,13 +10,14 @@ from sqlalchemy.orm import Session
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
+from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
 from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
 from danswer.connectors.interfaces import GenerateDocumentsOutput
 from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.models import BasicExpertInfo
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
-from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.engine import get_session_with_tenant
 from danswer.file_processing.extract_file_text import check_file_ext_is_valid
 from danswer.file_processing.extract_file_text import detect_encoding
 from danswer.file_processing.extract_file_text import extract_file_text
@@ -159,10 +160,12 @@ class LocalFileConnector(LoadConnector):
     def __init__(
         self,
         file_locations: list[Path | str],
+        tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
         batch_size: int = INDEX_BATCH_SIZE,
     ) -> None:
         self.file_locations = [Path(file_location) for file_location in file_locations]
         self.batch_size = batch_size
+        self.tenant_id = tenant_id
         self.pdf_pass: str | None = None
 
     def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
@@ -171,7 +174,7 @@ class LocalFileConnector(LoadConnector):
 
     def load_from_state(self) -> GenerateDocumentsOutput:
         documents: list[Document] = []
-        with Session(get_sqlalchemy_engine()) as db_session:
+        with get_session_with_tenant(self.tenant_id) as db_session:
             for file_path in self.file_locations:
                 current_datetime = datetime.now(timezone.utc)
                 files = _read_files_and_metadata(
diff --git a/backend/danswer/db/index_attempt.py b/backend/danswer/db/index_attempt.py
index 32e20d065..d9b1569e4 100644
--- a/backend/danswer/db/index_attempt.py
+++ b/backend/danswer/db/index_attempt.py
@@ -435,14 +435,13 @@ def cancel_indexing_attempts_for_ccpair(
 
     db_session.execute(stmt)
 
-    db_session.commit()
-
 
 def cancel_indexing_attempts_past_model(
     db_session: Session,
 ) -> None:
     """Stops all indexing attempts that are in progress or not started for any
     embedding model that not present/future"""
+
     db_session.execute(
         update(IndexAttempt)
         .where(
@@ -455,8 +454,6 @@ def cancel_indexing_attempts_past_model(
         .values(status=IndexingStatus.FAILED)
     )
 
-    db_session.commit()
-
 
 def count_unique_cc_pairs_with_successful_index_attempts(
     search_settings_id: int | None,
diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py
index ea513b5c2..d835a25e2 100644
--- a/backend/danswer/server/documents/cc_pair.py
+++ b/backend/danswer/server/documents/cc_pair.py
@@ -154,6 +154,7 @@ def update_cc_pair_status(
         user=user,
         get_editable=True,
     )
+
     if not cc_pair:
         raise HTTPException(
             status_code=400,
@@ -163,7 +164,6 @@ def update_cc_pair_status(
 
     if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
         cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
-        # Just for good measure
         cancel_indexing_attempts_past_model(db_session)
 
     update_connector_credential_pair_from_id(
@@ -172,6 +172,8 @@ def update_cc_pair_status(
         status=status_update_request.status,
     )
 
+    db_session.commit()
+
 
 @router.put("/admin/cc-pair/{cc_pair_id}/name")
 def update_cc_pair_name(
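Reviewer note: the two `db_session.commit()` calls deleted from `index_attempt.py` reappear once in the route above — the db-layer helpers now leave transaction control to the caller, so the status update and the attempt cancellations commit together. A sketch of the convention, with stand-in helper names and bodies rather than the real functions:

```python
# Sketch of the transaction-ownership pattern these changes move toward:
# helpers execute statements against the caller's session but never commit;
# the API route (or task) that owns the session commits once at the end.
from sqlalchemy.orm import Session


def _cancel_attempts(cc_pair_id: int, db_session: Session) -> None:
    ...  # issues UPDATE statements; intentionally no db_session.commit()


def _update_status(cc_pair_id: int, status: str, db_session: Session) -> None:
    ...  # issues UPDATE statements; intentionally no db_session.commit()


def pause_cc_pair(cc_pair_id: int, db_session: Session) -> None:
    _cancel_attempts(cc_pair_id, db_session)
    _update_status(cc_pair_id, "PAUSED", db_session)
    db_session.commit()  # one atomic commit for the whole operation
```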
diff --git a/backend/danswer/server/manage/search_settings.py b/backend/danswer/server/manage/search_settings.py
index 6436a0bd8..79f690e5d 100644
--- a/backend/danswer/server/manage/search_settings.py
+++ b/backend/danswer/server/manage/search_settings.py
@@ -115,6 +115,7 @@ def set_new_search_settings(
     for cc_pair in get_connector_credential_pairs(db_session):
         resync_cc_pair(cc_pair, db_session=db_session)
 
+    db_session.commit()
     return IdReturn(id=new_search_settings.id)
diff --git a/backend/danswer/server/query_and_chat/token_limit.py b/backend/danswer/server/query_and_chat/token_limit.py
index 3f5d76bac..6221eae33 100644
--- a/backend/danswer/server/query_and_chat/token_limit.py
+++ b/backend/danswer/server/query_and_chat/token_limit.py
@@ -13,6 +13,7 @@ from sqlalchemy.orm import Session
 
 from danswer.auth.users import current_user
 from danswer.db.engine import get_session_context_manager
+from danswer.db.engine import get_session_with_tenant
 from danswer.db.models import ChatMessage
 from danswer.db.models import ChatSession
 from danswer.db.models import TokenRateLimit
@@ -20,6 +21,7 @@ from danswer.db.models import User
 from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import fetch_versioned_implementation
 from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
+from shared_configs.configs import current_tenant_id
 
 logger = setup_logger()
@@ -39,11 +41,11 @@ def check_token_rate_limits(
     versioned_rate_limit_strategy = fetch_versioned_implementation(
         "danswer.server.query_and_chat.token_limit", "_check_token_rate_limits"
     )
-    return versioned_rate_limit_strategy(user)
+    return versioned_rate_limit_strategy(user, current_tenant_id.get())
 
 
-def _check_token_rate_limits(_: User | None) -> None:
-    _user_is_rate_limited_by_global()
+def _check_token_rate_limits(_: User | None, tenant_id: str | None) -> None:
+    _user_is_rate_limited_by_global(tenant_id)
 
 
 """
@@ -51,8 +53,8 @@ Global rate limits
 """
 
 
-def _user_is_rate_limited_by_global() -> None:
-    with get_session_context_manager() as db_session:
+def _user_is_rate_limited_by_global(tenant_id: str | None) -> None:
+    with get_session_with_tenant(tenant_id) as db_session:
         global_rate_limits = fetch_all_global_token_rate_limits(
             db_session=db_session, enabled_only=True, ordered=False
         )
diff --git a/backend/ee/danswer/server/query_and_chat/token_limit.py b/backend/ee/danswer/server/query_and_chat/token_limit.py
index 538458fb6..b4c588dc4 100644
--- a/backend/ee/danswer/server/query_and_chat/token_limit.py
+++ b/backend/ee/danswer/server/query_and_chat/token_limit.py
@@ -12,7 +12,7 @@ from sqlalchemy import func
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
-from danswer.db.engine import get_session_context_manager
+from danswer.db.engine import get_session_with_tenant
 from danswer.db.models import ChatMessage
 from danswer.db.models import ChatSession
 from danswer.db.models import TokenRateLimit
@@ -28,21 +28,21 @@ from ee.danswer.db.api_key import is_api_key_email_address
 from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits
 
 
-def _check_token_rate_limits(user: User | None) -> None:
+def _check_token_rate_limits(user: User | None, tenant_id: str | None) -> None:
     if user is None:
         # Unauthenticated users are only rate limited by global settings
-        _user_is_rate_limited_by_global()
+        _user_is_rate_limited_by_global(tenant_id)
 
     elif is_api_key_email_address(user.email):
         # API keys are only rate limited by global settings
-        _user_is_rate_limited_by_global()
+        _user_is_rate_limited_by_global(tenant_id)
 
     else:
         run_functions_tuples_in_parallel(
             [
-                (_user_is_rate_limited, (user.id,)),
-                (_user_is_rate_limited_by_group, (user.id,)),
-                (_user_is_rate_limited_by_global, ()),
+                (_user_is_rate_limited, (user.id, tenant_id)),
+                (_user_is_rate_limited_by_group, (user.id, tenant_id)),
+                (_user_is_rate_limited_by_global, (tenant_id,)),
             ]
         )
@@ -52,8 +52,8 @@ User rate limits
 """
 
 
-def _user_is_rate_limited(user_id: UUID) -> None:
-    with get_session_context_manager() as db_session:
+def _user_is_rate_limited(user_id: UUID, tenant_id: str | None) -> None:
+    with get_session_with_tenant(tenant_id) as db_session:
         user_rate_limits = fetch_all_user_token_rate_limits(
             db_session=db_session, enabled_only=True, ordered=False
         )
@@ -93,8 +93,8 @@ User Group rate limits
 """
 
 
-def _user_is_rate_limited_by_group(user_id: UUID) -> None:
-    with get_session_context_manager() as db_session:
+def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
+    with get_session_with_tenant(tenant_id) as db_session:
         group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)
 
         if group_rate_limits:
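Reviewer note: `check_token_rate_limits` now resolves the tenant once via `current_tenant_id.get()` and threads it through explicitly. That likely matters for the EE variant above, since `run_functions_tuples_in_parallel` runs the three checks on worker threads, which would not see the request's context automatically. A sketch of how such a context variable is typically set — the variable name matches the import in the diff, but the middleware function is hypothetical:

```python
# Sketch of how current_tenant_id is assumed to work: a ContextVar that
# request middleware sets per request. Work dispatched to other threads does
# not inherit this context, hence the explicit tenant_id arguments above.
from contextvars import ContextVar

current_tenant_id: ContextVar[str | None] = ContextVar(
    "current_tenant_id", default=None
)


def handle_request(tenant_id: str) -> None:
    token = current_tenant_id.set(tenant_id)
    try:
        # Code on *this* thread may call current_tenant_id.get(); parallel
        # helpers must receive tenant_id as an argument instead.
        ...
    finally:
        current_tenant_id.reset(token)
```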
diff --git a/backend/scripts/force_delete_connector_by_id.py b/backend/scripts/force_delete_connector_by_id.py
index 0a9857304..241242f4a 100755
--- a/backend/scripts/force_delete_connector_by_id.py
+++ b/backend/scripts/force_delete_connector_by_id.py
@@ -206,6 +206,8 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
             logger.notice(f"Deleting file {file_name}")
             file_store.delete_file(file_name)
 
+    db_session.commit()
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Delete a connector by its ID")
diff --git a/web/src/app/auth/logout/route.ts b/web/src/app/auth/logout/route.ts
index 7de902c7a..e3bae04bb 100644
--- a/web/src/app/auth/logout/route.ts
+++ b/web/src/app/auth/logout/route.ts
@@ -1,3 +1,4 @@
+import { CLOUD_ENABLED } from "@/lib/constants";
 import { getAuthTypeMetadataSS, logoutSS } from "@/lib/userSS";
 import { NextRequest } from "next/server";
 
@@ -6,8 +7,38 @@ export const POST = async (request: NextRequest) => {
   // Needed since env variables don't work well on the client-side
   const authTypeMetadata = await getAuthTypeMetadataSS();
   const response = await logoutSS(authTypeMetadata.authType, request.headers);
-  if (!response || response.ok) {
+
+  if (response && !response.ok) {
+    return new Response(response.body, { status: response?.status });
+  }
+
+  // Delete cookies only if cloud is enabled (jwt auth)
+  if (CLOUD_ENABLED) {
+    const cookiesToDelete = ["fastapiusersauth", "tenant_details"];
+    const cookieOptions = {
+      path: "/",
+      secure: process.env.NODE_ENV === "production",
+      httpOnly: true,
+      sameSite: "lax" as const,
+    };
+
+    // Logout successful, delete cookies
+    const headers = new Headers();
+
+    cookiesToDelete.forEach((cookieName) => {
+      headers.append(
+        "Set-Cookie",
+        `${cookieName}=; Max-Age=0; ${Object.entries(cookieOptions)
+          .map(([key, value]) => `${key}=${value}`)
+          .join("; ")}`
+      );
+    });
+
+    return new Response(null, {
+      status: 204,
+      headers: headers,
+    });
+  } else {
     return new Response(null, { status: 204 });
   }
-  return new Response(response.body, { status: response?.status });
 };
diff --git a/web/src/lib/constants.ts b/web/src/lib/constants.ts
index c0b916eba..15e5b5cbc 100644
--- a/web/src/lib/constants.ts
+++ b/web/src/lib/constants.ts
@@ -58,6 +58,7 @@ export const DISABLE_LLM_DOC_RELEVANCE =
 export const CLOUD_ENABLED =
   process.env.NEXT_PUBLIC_CLOUD_ENABLED?.toLowerCase() === "true";
 
+
 export const REGISTRATION_URL =
   process.env.INTERNAL_URL || "http://127.0.0.1:3001";