diff --git a/backend/ee.supervisord.conf b/backend/ee.supervisord.conf
index 23788a3aa..583de3af0 100644
--- a/backend/ee.supervisord.conf
+++ b/backend/ee.supervisord.conf
@@ -11,6 +11,14 @@ stdout_logfile_maxbytes=52428800
 redirect_stderr=true
 autorestart=true
 
+# Automatic syncing of document access permissions from sources
+[program:permission_syncing]
+command=python ee/danswer/background/permission_sync.py
+stdout_logfile=/var/log/permission_sync.log
+stdout_logfile_maxbytes=52428800
+redirect_stderr=true
+autorestart=true
+
 # Background jobs that must be run async due to long time to completion
 [program:celery_worker]
 command=celery -A ee.danswer.background.celery worker --loglevel=INFO --logfile=/var/log/celery_worker.log
diff --git a/backend/ee/danswer/background/permission_sync.py b/backend/ee/danswer/background/permission_sync.py
new file mode 100644
index 000000000..b3e8845ab
--- /dev/null
+++ b/backend/ee/danswer/background/permission_sync.py
@@ -0,0 +1,221 @@
+import logging
+import time
+from datetime import datetime
+
+import dask
+from dask.distributed import Client
+from dask.distributed import Future
+from distributed import LocalCluster
+from sqlalchemy.orm import Session
+
+from danswer.background.indexing.dask_utils import ResourceLogger
+from danswer.background.indexing.job_client import SimpleJob
+from danswer.background.indexing.job_client import SimpleJobClient
+from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
+from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
+from danswer.configs.constants import DocumentSource
+from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.models import PermissionSyncStatus
+from danswer.utils.logger import setup_logger
+from ee.danswer.configs.app_configs import NUM_PERMISSION_WORKERS
+from ee.danswer.connectors.factory import CONNECTOR_PERMISSION_FUNC_MAP
+from ee.danswer.db.connector import fetch_sources_with_connectors
+from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source
+from ee.danswer.db.permission_sync import create_perm_sync
+from ee.danswer.db.permission_sync import expire_perm_sync_timed_out
+from ee.danswer.db.permission_sync import get_perm_sync_attempt
+from ee.danswer.db.permission_sync import mark_all_inprogress_permission_sync_failed
+from shared_configs.configs import LOG_LEVEL
+
+logger = setup_logger()
+
+# If a sync job dies, it's most likely due to resource constraints;
+# restarting it just delays the eventual failure and is not useful to the user
+dask.config.set({"distributed.scheduler.allowed-failures": 0})
+
+
+def cleanup_perm_sync_jobs(
+    existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob],
+    # Just reusing the indexing jobs timeout, fine for now
+    timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
+) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]:
+    existing_jobs_copy = existing_jobs.copy()
+
+    with Session(get_sqlalchemy_engine()) as db_session:
+        # clean up completed jobs
+        for (attempt_id, details), job in existing_jobs.items():
+            perm_sync_attempt = get_perm_sync_attempt(
+                attempt_id=attempt_id, db_session=db_session
+            )
+
+            # do nothing for ongoing jobs that haven't been stopped
+            if (
+                not job.done()
+                and perm_sync_attempt.status == PermissionSyncStatus.IN_PROGRESS
+            ):
+                continue
+
+            if job.status == "error":
+                logger.error(job.exception())
+
+            job.release()
+            del existing_jobs_copy[(attempt_id, details)]
+
+        # clean up in-progress jobs that were never completed
+        expire_perm_sync_timed_out(
+            timeout_hours=timeout_hours,
+            db_session=db_session,
+        )
+
+    return existing_jobs_copy
+
+
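A note on the bookkeeping structure used throughout this file: `existing_jobs` mixes two kinds of keys, `(run_id, DocumentSource)` for group syncs and `(run_id, cc_pair_id)` for per-cc-pair ACL syncs. A minimal sketch, not part of the diff; the job handles and ids here are stand-ins:

```python
from danswer.configs.constants import DocumentSource

group_sync_job = acl_sync_job = None  # stand-ins for Future / SimpleJob handles

# hypothetical state after one group sync and one ACL sync have been kicked off
existing_jobs = {
    (1, DocumentSource.CONFLUENCE): group_sync_job,  # group sync, keyed by source
    (2, 42): acl_sync_job,  # ACL sync, keyed by cc-pair id
}

# the create_* helpers below recover "what is already running" from the key type
sources_w_runs = [k[1] for k in existing_jobs if isinstance(k[1], DocumentSource)]
cc_pair_ids_w_runs = [k[1] for k in existing_jobs if isinstance(k[1], int)]
```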
+def create_group_sync_jobs(
+    existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob],
+    client: Client | SimpleJobClient,
+) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]:
+    """Creates a new relational DB group permission sync job for each source that:
+    - has permission sync enabled
+    - has at least 1 connector (enabled or paused)
+    - has no sync already running
+    """
+    existing_jobs_copy = existing_jobs.copy()
+    sources_w_runs = [
+        key[1]
+        for key in existing_jobs_copy.keys()
+        if isinstance(key[1], DocumentSource)
+    ]
+    with Session(get_sqlalchemy_engine()) as db_session:
+        sources_w_connector = fetch_sources_with_connectors(db_session)
+        for source_type in sources_w_connector:
+            if source_type not in CONNECTOR_PERMISSION_FUNC_MAP:
+                continue
+            if source_type in sources_w_runs:
+                continue
+
+            db_group_fnc, _ = CONNECTOR_PERMISSION_FUNC_MAP[source_type]
+            perm_sync = create_perm_sync(
+                source_type=source_type,
+                group_update=True,
+                cc_pair_id=None,
+                db_session=db_session,
+            )
+
+            run = client.submit(db_group_fnc, pure=False)
+
+            logger.info(
+                f"Kicked off group permission sync for source type {source_type}"
+            )
+
+            if run:
+                existing_jobs_copy[(perm_sync.id, source_type)] = run
+
+    return existing_jobs_copy
+
+
+def create_connector_perm_sync_jobs(
+    existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob],
+    client: Client | SimpleJobClient,
+) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]:
+    """Creates a Document Index ACL sync job for each cc-pair where:
+    - the source type has permission sync enabled
+    - no sync is already running for that cc-pair
+    """
+    existing_jobs_copy = existing_jobs.copy()
+    cc_pairs_w_runs = [
+        key[1]
+        for key in existing_jobs_copy.keys()
+        if isinstance(key[1], int)  # ACL sync jobs are keyed by cc-pair id
+    ]
+    with Session(get_sqlalchemy_engine()) as db_session:
+        sources_w_connector = fetch_sources_with_connectors(db_session)
+        for source_type in sources_w_connector:
+            if source_type not in CONNECTOR_PERMISSION_FUNC_MAP:
+                continue
+
+            _, index_sync_fnc = CONNECTOR_PERMISSION_FUNC_MAP[source_type]
+
+            cc_pairs = get_cc_pairs_by_source(source_type, db_session)
+
+            for cc_pair in cc_pairs:
+                if cc_pair.id in cc_pairs_w_runs:
+                    continue
+
+                perm_sync = create_perm_sync(
+                    source_type=source_type,
+                    group_update=False,
+                    cc_pair_id=cc_pair.id,
+                    db_session=db_session,
+                )
+
+                run = client.submit(index_sync_fnc, cc_pair.id, pure=False)
+
+                logger.info(f"Kicked off ACL sync for cc-pair {cc_pair.id}")
+
+                if run:
+                    existing_jobs_copy[(perm_sync.id, cc_pair.id)] = run
+
+    return existing_jobs_copy
+
+
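For context on what gets submitted to the client: each entry in `CONNECTOR_PERMISSION_FUNC_MAP` is a `(group_sync_fn, acl_sync_fn)` pair, where group sync functions take no arguments and ACL sync functions take a cc-pair id. A sketch of the dispatch, assuming the backend package is on `PYTHONPATH`; the cc-pair id 42 is made up:

```python
from danswer.configs.constants import DocumentSource
from ee.danswer.connectors.factory import CONNECTOR_PERMISSION_FUNC_MAP

db_group_fnc, index_sync_fnc = CONNECTOR_PERMISSION_FUNC_MAP[DocumentSource.CONFLUENCE]

db_group_fnc()  # group membership sync; currently a logged no-op
index_sync_fnc(42)  # document-index ACL sync for a hypothetical cc-pair id
```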
+def permission_loop(delay: int = 60, num_workers: int = NUM_PERMISSION_WORKERS) -> None:
+    client: Client | SimpleJobClient
+    if DASK_JOB_CLIENT_ENABLED:
+        cluster_primary = LocalCluster(
+            n_workers=num_workers,
+            threads_per_worker=1,
+            # there are warnings about high memory usage + "Event loop unresponsive"
+            # which are not relevant to us since our workers are expected to use a
+            # lot of memory + involve CPU-intensive tasks that will not relinquish
+            # the event loop
+            silence_logs=logging.ERROR,
+        )
+        client = Client(cluster_primary)
+        if LOG_LEVEL.lower() == "debug":
+            client.register_worker_plugin(ResourceLogger())
+    else:
+        client = SimpleJobClient(n_workers=num_workers)
+
+    existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob] = {}
+    engine = get_sqlalchemy_engine()
+
+    with Session(engine) as db_session:
+        # Any jobs still in progress on restart must have died
+        mark_all_inprogress_permission_sync_failed(db_session)
+
+    while True:
+        start = time.time()
+        start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
+        logger.info(f"Running Permission Sync, current UTC time: {start_time_utc}")
+
+        if existing_jobs:
+            logger.debug(
+                "Found existing permission sync jobs: "
+                f"{[(key, job.status) for key, job in existing_jobs.items()]}"
+            )
+
+        try:
+            # TODO turn this on when it works
+            """
+            existing_jobs = cleanup_perm_sync_jobs(existing_jobs=existing_jobs)
+            existing_jobs = create_group_sync_jobs(
+                existing_jobs=existing_jobs, client=client
+            )
+            existing_jobs = create_connector_perm_sync_jobs(
+                existing_jobs=existing_jobs, client=client
+            )
+            """
+        except Exception as e:
+            logger.exception(f"Failed to run update due to {e}")
+        sleep_time = delay - (time.time() - start)
+        if sleep_time > 0:
+            time.sleep(sleep_time)
+
+
+def update__main() -> None:
+    logger.info("Starting Permission Syncing Loop")
+    permission_loop()
+
+
+if __name__ == "__main__":
+    update__main()
diff --git a/backend/ee/danswer/configs/app_configs.py b/backend/ee/danswer/configs/app_configs.py
index 6b576a14e..1430a4991 100644
--- a/backend/ee/danswer/configs/app_configs.py
+++ b/backend/ee/danswer/configs/app_configs.py
@@ -15,3 +15,9 @@ _API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
 API_KEY_HASH_ROUNDS = (
     int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
 )
+
+
+#####
+# Auto Permission Sync
+#####
+NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
diff --git a/backend/ee/danswer/connectors/__init__.py b/backend/ee/danswer/connectors/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/backend/ee/danswer/connectors/confluence/__init__.py b/backend/ee/danswer/connectors/confluence/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/backend/ee/danswer/connectors/confluence/perm_sync.py b/backend/ee/danswer/connectors/confluence/perm_sync.py
new file mode 100644
index 000000000..2985b47b0
--- /dev/null
+++ b/backend/ee/danswer/connectors/confluence/perm_sync.py
@@ -0,0 +1,12 @@
+from danswer.utils.logger import setup_logger
+
+
+logger = setup_logger()
+
+
+def confluence_update_db_group() -> None:
+    logger.debug("Group sync not yet implemented for Confluence, no-op")
+
+
+def confluence_update_index_acl(cc_pair_id: int) -> None:
+    logger.debug("ACL sync not yet implemented for Confluence, no-op")
diff --git a/backend/ee/danswer/connectors/factory.py b/backend/ee/danswer/connectors/factory.py
new file mode 100644
index 000000000..52f932494
--- /dev/null
+++ b/backend/ee/danswer/connectors/factory.py
@@ -0,0 +1,8 @@
+from danswer.configs.constants import DocumentSource
+from ee.danswer.connectors.confluence.perm_sync import confluence_update_db_group
+from ee.danswer.connectors.confluence.perm_sync import confluence_update_index_acl
+
+
+CONNECTOR_PERMISSION_FUNC_MAP = {
+    DocumentSource.CONFLUENCE: (confluence_update_db_group, confluence_update_index_acl)
+}
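The map above doubles as the feature flag: `permission_sync.py` skips any source not present in it. Supporting another source should then only require implementing the same two-function interface and registering it. A sketch with hypothetical Google Drive stubs (these functions do not exist in the diff):

```python
from danswer.configs.constants import DocumentSource
from ee.danswer.connectors.factory import CONNECTOR_PERMISSION_FUNC_MAP


def google_drive_update_db_group() -> None:
    ...  # would mirror confluence_update_db_group


def google_drive_update_index_acl(cc_pair_id: int) -> None:
    ...  # would mirror confluence_update_index_acl


CONNECTOR_PERMISSION_FUNC_MAP[DocumentSource.GOOGLE_DRIVE] = (
    google_drive_update_db_group,
    google_drive_update_index_acl,
)
```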
diff --git a/backend/ee/danswer/db/connector.py b/backend/ee/danswer/db/connector.py
new file mode 100644
index 000000000..44505f515
--- /dev/null
+++ b/backend/ee/danswer/db/connector.py
@@ -0,0 +1,16 @@
+from sqlalchemy import distinct
+from sqlalchemy.orm import Session
+
+from danswer.configs.constants import DocumentSource
+from danswer.db.models import Connector
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+
+def fetch_sources_with_connectors(db_session: Session) -> list[DocumentSource]:
+    sources = db_session.query(distinct(Connector.source)).all()  # type: ignore
+
+    document_sources = [source[0] for source in sources]
+
+    return document_sources
diff --git a/backend/ee/danswer/db/connector_credential_pair.py b/backend/ee/danswer/db/connector_credential_pair.py
new file mode 100644
index 000000000..a49381385
--- /dev/null
+++ b/backend/ee/danswer/db/connector_credential_pair.py
@@ -0,0 +1,22 @@
+from sqlalchemy.orm import Session
+
+from danswer.configs.constants import DocumentSource
+from danswer.db.models import Connector
+from danswer.db.models import ConnectorCredentialPair
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+
+def get_cc_pairs_by_source(
+    source_type: DocumentSource,
+    db_session: Session,
+) -> list[ConnectorCredentialPair]:
+    cc_pairs = (
+        db_session.query(ConnectorCredentialPair)
+        .join(ConnectorCredentialPair.connector)
+        .filter(Connector.source == source_type)
+        .all()
+    )
+
+    return cc_pairs
diff --git a/backend/ee/danswer/db/permission_sync.py b/backend/ee/danswer/db/permission_sync.py
new file mode 100644
index 000000000..7642bb653
--- /dev/null
+++ b/backend/ee/danswer/db/permission_sync.py
@@ -0,0 +1,72 @@
+from datetime import timedelta
+
+from sqlalchemy import func
+from sqlalchemy import select
+from sqlalchemy import update
+from sqlalchemy.exc import NoResultFound
+from sqlalchemy.orm import Session
+
+from danswer.configs.constants import DocumentSource
+from danswer.db.models import PermissionSyncRun
+from danswer.db.models import PermissionSyncStatus
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+
+def mark_all_inprogress_permission_sync_failed(
+    db_session: Session,
+) -> None:
+    stmt = (
+        update(PermissionSyncRun)
+        .where(PermissionSyncRun.status == PermissionSyncStatus.IN_PROGRESS)
+        .values(status=PermissionSyncStatus.FAILED)
+    )
+    db_session.execute(stmt)
+    db_session.commit()
+
+
+def get_perm_sync_attempt(attempt_id: int, db_session: Session) -> PermissionSyncRun:
+    stmt = select(PermissionSyncRun).where(PermissionSyncRun.id == attempt_id)
+    try:
+        return db_session.scalars(stmt).one()
+    except NoResultFound:
+        raise ValueError(f"No PermissionSyncRun found with id {attempt_id}")
+
+
+def expire_perm_sync_timed_out(
+    timeout_hours: int,
+    db_session: Session,
+) -> None:
+    cutoff_time = func.now() - timedelta(hours=timeout_hours)
+
+    update_stmt = (
+        update(PermissionSyncRun)
+        .where(
+            PermissionSyncRun.status == PermissionSyncStatus.IN_PROGRESS,
+            PermissionSyncRun.updated_at < cutoff_time,
+        )
+        .values(status=PermissionSyncStatus.FAILED, error_msg="timed out")
+    )
+
+    db_session.execute(update_stmt)
+    db_session.commit()
+
+
+def create_perm_sync(
+    source_type: DocumentSource,
+    group_update: bool,
+    cc_pair_id: int | None,
+    db_session: Session,
+) -> PermissionSyncRun:
+    new_run = PermissionSyncRun(
+        source_type=source_type,
+        status=PermissionSyncStatus.IN_PROGRESS,
+        group_update=group_update,
+        cc_pair_id=cc_pair_id,
+    )
+
+    db_session.add(new_run)
+    db_session.commit()
+
+    return new_run
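How these helpers compose over a run's lifecycle, as a sketch assuming a reachable database; the one-hour timeout is an arbitrary example value:

```python
from sqlalchemy.orm import Session

from danswer.configs.constants import DocumentSource
from danswer.db.engine import get_sqlalchemy_engine
from ee.danswer.db.permission_sync import create_perm_sync
from ee.danswer.db.permission_sync import expire_perm_sync_timed_out

with Session(get_sqlalchemy_engine()) as db_session:
    # record a new group sync attempt as IN_PROGRESS
    run = create_perm_sync(
        source_type=DocumentSource.CONFLUENCE,
        group_update=True,  # group membership sync, not a cc-pair ACL sync
        cc_pair_id=None,
        db_session=db_session,
    )

    # later: mark any run stuck IN_PROGRESS for over an hour as FAILED
    expire_perm_sync_timed_out(timeout_hours=1, db_session=db_session)
```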
update_env["PYTHONPATH"] = "." + cmd_perm_sync = ["python", "ee.danswer/background/permission_sync.py"] + + indexing_process = subprocess.Popen( + cmd_perm_sync, + env=update_env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + + perm_sync_thread = threading.Thread( + target=monitor_process, args=("INDEXING", indexing_process) + ) + perm_sync_thread.start() + perm_sync_thread.join() + except Exception: + pass worker_thread.join() beat_thread.join()