Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-07-07 13:10:24 +02:00)

Commit: various multi tenant improvements (#2803)

* various multi tenant improvements
* nit
* ensure consistent db session operations
* minor robustification
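Note: most of the backend hunks below share one pattern — replacing direct `Session(get_sqlalchemy_engine())` usage with `get_session_with_tenant(tenant_id)`. A minimal sketch of what such a helper can look like, assuming schema-per-tenant multi-tenancy selected via the Postgres search_path (the real implementation lives in danswer.db.engine and may differ in detail):

    from contextlib import contextmanager
    from typing import Generator

    from sqlalchemy import create_engine, text
    from sqlalchemy.orm import Session

    POSTGRES_DEFAULT_SCHEMA = "public"  # assumed default schema name
    _engine = create_engine("postgresql://user:pass@localhost/danswer")  # placeholder DSN


    @contextmanager
    def get_session_with_tenant(
        tenant_id: str | None = None,
    ) -> Generator[Session, None, None]:
        """Yield a session whose unqualified queries hit the tenant's schema."""
        schema = tenant_id or POSTGRES_DEFAULT_SCHEMA
        with Session(_engine) as session:
            # Scope all unqualified table references in this session to the tenant
            session.execute(text(f'SET search_path TO "{schema}"'))
            yield session

With a helper like this in place, threading `tenant_id` through task and function signatures, as the hunks below do, is what keeps background work pointed at the right tenant's data.
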
@@ -12,7 +12,7 @@ from danswer.configs.app_configs import JOB_TIMEOUT
 from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
 from danswer.configs.constants import DanswerRedisLocks
 from danswer.db.connector_credential_pair import get_connector_credential_pairs
-from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.engine import get_session_with_tenant
 from danswer.db.enums import ConnectorCredentialPairStatus
 from danswer.db.models import ConnectorCredentialPair
 from danswer.redis.redis_pool import get_redis_client

@@ -36,7 +36,7 @@ def check_for_connector_deletion_task(tenant_id: str | None) -> None:
         if not lock_beat.acquire(blocking=False):
             return
 
-        with Session(get_sqlalchemy_engine()) as db_session:
+        with get_session_with_tenant(tenant_id) as db_session:
             cc_pairs = get_connector_credential_pairs(db_session)
             for cc_pair in cc_pairs:
                 try_generate_document_cc_pair_cleanup_tasks(

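Note: the beat task guards itself with a non-blocking Redis lock so overlapping schedules skip work instead of stacking up. A sketch of the pattern with redis-py; the lock name and timeout are illustrative stand-ins for DanswerRedisLocks and CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT:

    import redis

    r = redis.Redis()  # stand-in for get_redis_client()
    lock_beat = r.lock("check_connector_deletion_beat", timeout=60)

    if lock_beat.acquire(blocking=False):
        try:
            ...  # do the tenant-scoped work inside get_session_with_tenant(tenant_id)
        finally:
            if lock_beat.owned():
                lock_beat.release()
    # else: another scheduler beat already holds the lock; skip this run
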
@@ -14,7 +14,7 @@ from sqlalchemy.orm import Session
 from danswer.background.celery.celery_app import task_logger
 from danswer.configs.app_configs import JOB_TIMEOUT
 from danswer.configs.constants import PostgresAdvisoryLocks
-from danswer.db.engine import get_sqlalchemy_engine  # type: ignore
+from danswer.db.engine import get_session_with_tenant
 
 
 @shared_task(
@@ -23,7 +23,7 @@ from danswer.db.engine import get_sqlalchemy_engine  # type: ignore
     bind=True,
     base=AbortableTask,
 )
-def kombu_message_cleanup_task(self: Any) -> int:
+def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
     """Runs periodically to clean up the kombu_message table"""
 
     # we will select messages older than this amount to clean up
@@ -35,7 +35,7 @@ def kombu_message_cleanup_task(self: Any) -> int:
     ctx["deleted"] = 0
     ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
     ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         # Exit the task if we can't take the advisory lock
         result = db_session.execute(
             text("SELECT pg_try_advisory_lock(:id)"),

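Note: besides the tenant-scoped session, the cleanup task relies on a Postgres advisory lock so only one worker sweeps the kombu_message table at a time. A sketch of that guard; the lock id here is hypothetical, the real one comes from PostgresAdvisoryLocks:

    from sqlalchemy import text

    KOMBU_CLEANUP_LOCK_ID = 1001  # hypothetical advisory lock id

    with get_session_with_tenant(tenant_id) as db_session:
        acquired = db_session.execute(
            text("SELECT pg_try_advisory_lock(:id)"), {"id": KOMBU_CLEANUP_LOCK_ID}
        ).scalar()
        if acquired:
            try:
                ...  # delete old kombu_message rows page by page
            finally:
                # session-level advisory locks must be released explicitly
                db_session.execute(
                    text("SELECT pg_advisory_unlock(:id)"),
                    {"id": KOMBU_CLEANUP_LOCK_ID},
                )
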
@@ -39,7 +39,6 @@ from danswer.db.document_set import fetch_document_sets_for_document
 from danswer.db.document_set import get_document_set_by_id
 from danswer.db.document_set import mark_document_set_as_synced
 from danswer.db.engine import get_session_with_tenant
-from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.index_attempt import delete_index_attempts
 from danswer.db.models import DocumentSet
 from danswer.db.models import UserGroup
@@ -341,7 +340,9 @@ def monitor_document_set_taskset(
     r.delete(rds.fence_key)
 
 
-def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
+def monitor_connector_deletion_taskset(
+    key_bytes: bytes, r: Redis, tenant_id: str | None
+) -> None:
     fence_key = key_bytes.decode("utf-8")
     cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
     if cc_pair_id is None:
@@ -367,7 +368,7 @@ def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
     if count > 0:
         return
 
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
         if not cc_pair:
             task_logger.warning(
@@ -529,7 +530,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> None:
 
         lock_beat.reacquire()
         for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
-            monitor_connector_deletion_taskset(key_bytes, r)
+            monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
 
         with get_session_with_tenant(tenant_id) as db_session:
             lock_beat.reacquire()

@@ -65,6 +65,7 @@ def _get_connector_runner(
             input_type=task,
             connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
             credential=attempt.connector_credential_pair.credential,
+            tenant_id=tenant_id,
         )
     except Exception as e:
         logger.exception(f"Unable to instantiate connector due to {e}")

@@ -118,6 +118,9 @@ class DocumentSource(str, Enum):
     NOT_APPLICABLE = "not_applicable"
 
 
+DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
+
+
 class NotificationType(str, Enum):
     REINDEX = "reindex"
 

@@ -4,6 +4,7 @@ from typing import Type
 from sqlalchemy.orm import Session
 
 from danswer.configs.constants import DocumentSource
+from danswer.configs.constants import DocumentSourceRequiringTenantContext
 from danswer.connectors.asana.connector import AsanaConnector
 from danswer.connectors.axero.connector import AxeroConnector
 from danswer.connectors.blob.connector import BlobStorageConnector
@@ -134,8 +135,13 @@ def instantiate_connector(
     input_type: InputType,
     connector_specific_config: dict[str, Any],
     credential: Credential,
+    tenant_id: str | None = None,
 ) -> BaseConnector:
     connector_class = identify_connector_class(source, input_type)
 
+    if source in DocumentSourceRequiringTenantContext:
+        connector_specific_config["tenant_id"] = tenant_id
+
     connector = connector_class(**connector_specific_config)
     new_credentials = connector.load_credentials(credential.credential_json)
 

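Note: only sources listed in DocumentSourceRequiringTenantContext (currently just DocumentSource.FILE) get the tenant id injected into their config; every other connector is constructed exactly as before. An illustrative call, assuming the leading parameter of instantiate_connector is `source` as the body above implies:

    connector = instantiate_connector(
        source=DocumentSource.FILE,
        input_type=InputType.LOAD_STATE,
        connector_specific_config={"file_locations": ["/path/to/upload"]},
        credential=credential,  # an existing Credential row
        tenant_id="tenant_abc",  # hypothetical tenant schema name
    )
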
@@ -10,13 +10,14 @@ from sqlalchemy.orm import Session
 
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
+from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
 from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
 from danswer.connectors.interfaces import GenerateDocumentsOutput
 from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.models import BasicExpertInfo
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
-from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.engine import get_session_with_tenant
 from danswer.file_processing.extract_file_text import check_file_ext_is_valid
 from danswer.file_processing.extract_file_text import detect_encoding
 from danswer.file_processing.extract_file_text import extract_file_text
@@ -159,10 +160,12 @@ class LocalFileConnector(LoadConnector):
     def __init__(
         self,
         file_locations: list[Path | str],
+        tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
         batch_size: int = INDEX_BATCH_SIZE,
     ) -> None:
         self.file_locations = [Path(file_location) for file_location in file_locations]
         self.batch_size = batch_size
+        self.tenant_id = tenant_id
         self.pdf_pass: str | None = None
 
     def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
@@ -171,7 +174,7 @@ class LocalFileConnector(LoadConnector):
 
     def load_from_state(self) -> GenerateDocumentsOutput:
         documents: list[Document] = []
-        with Session(get_sqlalchemy_engine()) as db_session:
+        with get_session_with_tenant(self.tenant_id) as db_session:
             for file_path in self.file_locations:
                 current_datetime = datetime.now(timezone.utc)
                 files = _read_files_and_metadata(

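Note: with the two changes above, LocalFileConnector reads uploaded files from the tenant's own Postgres-backed file store, which is why FILE is the one source in DocumentSourceRequiringTenantContext. A hypothetical instantiation; the default keeps single-tenant deployments unchanged:

    connector = LocalFileConnector(
        file_locations=["docs/handbook.pdf"],
        tenant_id="tenant_abc",  # omit to fall back to POSTGRES_DEFAULT_SCHEMA
    )
    connector.load_credentials({})  # credential payload shape is illustrative
    for batch in connector.load_from_state():
        print(f"loaded batch of {len(batch)} documents")
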
@@ -435,14 +435,13 @@ def cancel_indexing_attempts_for_ccpair(
 
     db_session.execute(stmt)
-
-    db_session.commit()
 
 
 def cancel_indexing_attempts_past_model(
     db_session: Session,
 ) -> None:
     """Stops all indexing attempts that are in progress or not started for
     any embedding model that not present/future"""
+
     db_session.execute(
         update(IndexAttempt)
         .where(
@@ -455,8 +454,6 @@ def cancel_indexing_attempts_past_model(
         .values(status=IndexingStatus.FAILED)
     )
-
-    db_session.commit()
 
 
 def count_unique_cc_pairs_with_successful_index_attempts(
     search_settings_id: int | None,

@@ -154,6 +154,7 @@ def update_cc_pair_status(
         user=user,
         get_editable=True,
     )
+
     if not cc_pair:
         raise HTTPException(
             status_code=400,
@@ -163,7 +164,6 @@ def update_cc_pair_status(
     if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
         cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
 
-        # Just for good measure
         cancel_indexing_attempts_past_model(db_session)
 
     update_connector_credential_pair_from_id(
@@ -172,6 +172,8 @@ def update_cc_pair_status(
         status=status_update_request.status,
     )
 
+    db_session.commit()
+
 
 @router.put("/admin/cc-pair/{cc_pair_id}/name")
 def update_cc_pair_name(

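Note: the commit() calls removed from cancel_indexing_attempts_for_ccpair and cancel_indexing_attempts_past_model reappear here, once, at the route level — this is the "ensure consistent db session operations" item from the commit message. A sketch of the convention, with hypothetical helper and statement names:

    def helper_one(db_session: Session) -> None:
        db_session.execute(some_update_stmt)  # stage changes only; no commit


    def helper_two(db_session: Session) -> None:
        db_session.execute(another_update_stmt)


    def route_handler(db_session: Session) -> None:
        helper_one(db_session)
        helper_two(db_session)
        db_session.commit()  # a single commit makes the whole request atomic
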
@@ -115,6 +115,7 @@ def set_new_search_settings(
     for cc_pair in get_connector_credential_pairs(db_session):
         resync_cc_pair(cc_pair, db_session=db_session)
 
+    db_session.commit()
     return IdReturn(id=new_search_settings.id)
 
 

@@ -13,6 +13,7 @@ from sqlalchemy.orm import Session
 
 from danswer.auth.users import current_user
 from danswer.db.engine import get_session_context_manager
+from danswer.db.engine import get_session_with_tenant
 from danswer.db.models import ChatMessage
 from danswer.db.models import ChatSession
 from danswer.db.models import TokenRateLimit
@@ -20,6 +21,7 @@ from danswer.db.models import User
 from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import fetch_versioned_implementation
 from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
+from shared_configs.configs import current_tenant_id
 
 
 logger = setup_logger()
@@ -39,11 +41,11 @@ def check_token_rate_limits(
     versioned_rate_limit_strategy = fetch_versioned_implementation(
         "danswer.server.query_and_chat.token_limit", "_check_token_rate_limits"
     )
-    return versioned_rate_limit_strategy(user)
+    return versioned_rate_limit_strategy(user, current_tenant_id.get())
 
 
-def _check_token_rate_limits(_: User | None) -> None:
-    _user_is_rate_limited_by_global()
+def _check_token_rate_limits(_: User | None, tenant_id: str | None) -> None:
+    _user_is_rate_limited_by_global(tenant_id)
 
 
 """
@@ -51,8 +53,8 @@ Global rate limits
 """
 
 
-def _user_is_rate_limited_by_global() -> None:
-    with get_session_context_manager() as db_session:
+def _user_is_rate_limited_by_global(tenant_id: str | None) -> None:
+    with get_session_with_tenant(tenant_id) as db_session:
         global_rate_limits = fetch_all_global_token_rate_limits(
             db_session=db_session, enabled_only=True, ordered=False
         )

@@ -12,7 +12,7 @@ from sqlalchemy import func
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
-from danswer.db.engine import get_session_context_manager
+from danswer.db.engine import get_session_with_tenant
 from danswer.db.models import ChatMessage
 from danswer.db.models import ChatSession
 from danswer.db.models import TokenRateLimit
@@ -28,21 +28,21 @@ from ee.danswer.db.api_key import is_api_key_email_address
 from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits
 
 
-def _check_token_rate_limits(user: User | None) -> None:
+def _check_token_rate_limits(user: User | None, tenant_id: str | None) -> None:
     if user is None:
         # Unauthenticated users are only rate limited by global settings
-        _user_is_rate_limited_by_global()
+        _user_is_rate_limited_by_global(tenant_id)
 
     elif is_api_key_email_address(user.email):
         # API keys are only rate limited by global settings
-        _user_is_rate_limited_by_global()
+        _user_is_rate_limited_by_global(tenant_id)
 
     else:
         run_functions_tuples_in_parallel(
             [
-                (_user_is_rate_limited, (user.id,)),
-                (_user_is_rate_limited_by_group, (user.id,)),
-                (_user_is_rate_limited_by_global, ()),
+                (_user_is_rate_limited, (user.id, tenant_id)),
+                (_user_is_rate_limited_by_group, (user.id, tenant_id)),
+                (_user_is_rate_limited_by_global, (tenant_id,)),
             ]
         )
 
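Note: the tenant id is read from the current_tenant_id ContextVar once, at the request boundary in check_token_rate_limits above, and then passed explicitly to every check — plausibly because run_functions_tuples_in_parallel executes these checks in worker threads, where a ContextVar set on the request thread is not visible. A small demonstration of that thread behavior:

    from concurrent.futures import ThreadPoolExecutor
    from contextvars import ContextVar

    # stand-in for shared_configs.configs.current_tenant_id
    current_tenant_id: ContextVar[str] = ContextVar("current_tenant_id", default="public")

    current_tenant_id.set("tenant_abc")

    with ThreadPoolExecutor() as pool:
        # the worker thread sees the default, not the request thread's value
        print(pool.submit(current_tenant_id.get).result())  # -> "public"
        # passing the value as an argument, as the checks above now do, is robust
        print(pool.submit(lambda t: t, current_tenant_id.get()).result())  # -> "tenant_abc"
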
@@ -52,8 +52,8 @@ User rate limits
 """
 
 
-def _user_is_rate_limited(user_id: UUID) -> None:
-    with get_session_context_manager() as db_session:
+def _user_is_rate_limited(user_id: UUID, tenant_id: str | None) -> None:
+    with get_session_with_tenant(tenant_id) as db_session:
         user_rate_limits = fetch_all_user_token_rate_limits(
             db_session=db_session, enabled_only=True, ordered=False
         )
@@ -93,8 +93,8 @@ User Group rate limits
 """
 
 
-def _user_is_rate_limited_by_group(user_id: UUID) -> None:
-    with get_session_context_manager() as db_session:
+def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
+    with get_session_with_tenant(tenant_id) as db_session:
         group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)
 
         if group_rate_limits:

@@ -206,6 +206,8 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
             logger.notice(f"Deleting file {file_name}")
             file_store.delete_file(file_name)
 
+    db_session.commit()
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Delete a connector by its ID")

@@ -1,3 +1,4 @@
+import { CLOUD_ENABLED } from "@/lib/constants";
 import { getAuthTypeMetadataSS, logoutSS } from "@/lib/userSS";
 import { NextRequest } from "next/server";
 
@@ -6,8 +7,38 @@ export const POST = async (request: NextRequest) => {
   // Needed since env variables don't work well on the client-side
   const authTypeMetadata = await getAuthTypeMetadataSS();
   const response = await logoutSS(authTypeMetadata.authType, request.headers);
-  if (!response || response.ok) {
+
+  if (response && !response.ok) {
+    return new Response(response.body, { status: response?.status });
+  }
+
+  // Delete cookies only if cloud is enabled (jwt auth)
+  if (CLOUD_ENABLED) {
+    const cookiesToDelete = ["fastapiusersauth", "tenant_details"];
+    const cookieOptions = {
+      path: "/",
+      secure: process.env.NODE_ENV === "production",
+      httpOnly: true,
+      sameSite: "lax" as const,
+    };
+
+    // Logout successful, delete cookies
+    const headers = new Headers();
+
+    cookiesToDelete.forEach((cookieName) => {
+      headers.append(
+        "Set-Cookie",
+        `${cookieName}=; Max-Age=0; ${Object.entries(cookieOptions)
+          .map(([key, value]) => `${key}=${value}`)
+          .join("; ")}`
+      );
+    });
+
+    return new Response(null, {
+      status: 204,
+      headers: headers,
+    });
+  } else {
     return new Response(null, { status: 204 });
   }
-  return new Response(response.body, { status: response?.status });
 };

@@ -58,6 +58,7 @@ export const DISABLE_LLM_DOC_RELEVANCE =
 
 export const CLOUD_ENABLED =
   process.env.NEXT_PUBLIC_CLOUD_ENABLED?.toLowerCase() === "true";
+
 export const REGISTRATION_URL =
   process.env.INTERNAL_URL || "http://127.0.0.1:3001";
 