Permission Sync Framework (#44)

Yuhong Sun 2024-04-20 20:34:20 -07:00 committed by Chris Weaver
parent 1984f2c1ca
commit 0c827d1e6c
11 changed files with 385 additions and 0 deletions

backend/supervisord.conf

@@ -11,6 +11,14 @@ stdout_logfile_maxbytes=52428800
redirect_stderr=true
autorestart=true

# Automatic sync-ing of access to documents from sources
[program:permission_syncing]
command=python ee/danswer/background/permission_sync.py
stdout_logfile=/var/log/permission_sync.log
stdout_logfile_maxbytes=52428800
redirect_stderr=true
autorestart=true

# Background jobs that must be run async due to long time to completion
[program:celery_worker]
command=celery -A ee.danswer.background.celery worker --loglevel=INFO --logfile=/var/log/celery_worker.log

ee/danswer/background/permission_sync.py

@@ -0,0 +1,221 @@
import logging
import time
from datetime import datetime

import dask
from dask.distributed import Client
from dask.distributed import Future
from distributed import LocalCluster
from sqlalchemy.orm import Session

from danswer.background.indexing.dask_utils import ResourceLogger
from danswer.background.indexing.job_client import SimpleJob
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.constants import DocumentSource
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import PermissionSyncStatus
from danswer.utils.logger import setup_logger
from ee.danswer.configs.app_configs import NUM_PERMISSION_WORKERS
from ee.danswer.connectors.factory import CONNECTOR_PERMISSION_FUNC_MAP
from ee.danswer.db.connector import fetch_sources_with_connectors
from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source
from ee.danswer.db.permission_sync import create_perm_sync
from ee.danswer.db.permission_sync import expire_perm_sync_timed_out
from ee.danswer.db.permission_sync import get_perm_sync_attempt
from ee.danswer.db.permission_sync import mark_all_inprogress_permission_sync_failed
from shared_configs.configs import LOG_LEVEL

logger = setup_logger()

# If the indexing dies, it's most likely due to resource constraints;
# restarting just delays the eventual failure, which is not useful to the user
dask.config.set({"distributed.scheduler.allowed-failures": 0})


def cleanup_perm_sync_jobs(
    existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob],
    # Just reusing the same timeout, fine for now
    timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]:
    existing_jobs_copy = existing_jobs.copy()

    with Session(get_sqlalchemy_engine()) as db_session:
        # clean up completed jobs
        for (attempt_id, details), job in existing_jobs.items():
            perm_sync_attempt = get_perm_sync_attempt(
                attempt_id=attempt_id, db_session=db_session
            )

            # do nothing for ongoing jobs that haven't been stopped
            if (
                not job.done()
                and perm_sync_attempt.status == PermissionSyncStatus.IN_PROGRESS
            ):
                continue

            if job.status == "error":
                logger.error(job.exception())

            job.release()
            del existing_jobs_copy[(attempt_id, details)]

        # clean up in-progress jobs that were never completed
        expire_perm_sync_timed_out(
            timeout_hours=timeout_hours,
            db_session=db_session,
        )

    return existing_jobs_copy


def create_group_sync_jobs(
    existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob],
    client: Client | SimpleJobClient,
) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]:
    """Creates a new relational DB group permission sync job for each source that:
    - has permission sync enabled
    - has at least 1 connector (enabled or paused)
    - has no sync already running
    """
    existing_jobs_copy = existing_jobs.copy()
    # Group sync jobs are keyed by (attempt_id, DocumentSource)
    sources_w_runs = [
        key[1]
        for key in existing_jobs_copy.keys()
        if isinstance(key[1], DocumentSource)
    ]

    with Session(get_sqlalchemy_engine()) as db_session:
        sources_w_connector = fetch_sources_with_connectors(db_session)
        for source_type in sources_w_connector:
            if source_type not in CONNECTOR_PERMISSION_FUNC_MAP:
                continue
            if source_type in sources_w_runs:
                continue

            db_group_fnc, _ = CONNECTOR_PERMISSION_FUNC_MAP[source_type]

            perm_sync = create_perm_sync(
                source_type=source_type,
                group_update=True,
                cc_pair_id=None,
                db_session=db_session,
            )

            run = client.submit(db_group_fnc, pure=False)

            logger.info(
                f"Kicked off group permission sync for source type {source_type}"
            )

            if run:
                existing_jobs_copy[(perm_sync.id, source_type)] = run

    return existing_jobs_copy
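
Note on the submit call above: Dask treats submitted functions as pure by default, so an identical call signature can be served from a cached result instead of being re-run. Passing pure=False forces a fresh execution each time, which matters here because the same zero-argument group sync function is resubmitted on every loop pass.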


def create_connector_perm_sync_jobs(
    existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob],
    client: Client | SimpleJobClient,
) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]:
    """Creates a Document Index ACL update sync job for each cc-pair where:
    - the source type has permission sync enabled
    - no sync is already running for that cc-pair
    """
    existing_jobs_copy = existing_jobs.copy()
    # ACL sync jobs are keyed by cc-pair id (an int); group sync jobs use
    # DocumentSource keys and must be excluded here
    cc_pairs_w_runs = [
        key[1]
        for key in existing_jobs_copy.keys()
        if isinstance(key[1], int)
    ]

    with Session(get_sqlalchemy_engine()) as db_session:
        sources_w_connector = fetch_sources_with_connectors(db_session)
        for source_type in sources_w_connector:
            if source_type not in CONNECTOR_PERMISSION_FUNC_MAP:
                continue

            _, index_sync_fnc = CONNECTOR_PERMISSION_FUNC_MAP[source_type]

            cc_pairs = get_cc_pairs_by_source(source_type, db_session)

            for cc_pair in cc_pairs:
                if cc_pair.id in cc_pairs_w_runs:
                    continue

                perm_sync = create_perm_sync(
                    source_type=source_type,
                    group_update=False,
                    cc_pair_id=cc_pair.id,
                    db_session=db_session,
                )

                run = client.submit(index_sync_fnc, cc_pair.id, pure=False)

                logger.info(f"Kicked off ACL sync for cc-pair {cc_pair.id}")

                if run:
                    existing_jobs_copy[(perm_sync.id, cc_pair.id)] = run

    return existing_jobs_copy
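
For orientation, the existing_jobs dict is keyed by (attempt_id, detail) tuples, where detail is a DocumentSource for group syncs and a cc-pair id (an int) for ACL syncs; the isinstance filters in the two functions above rely on that distinction. A minimal sketch (the string values stand in for real Future objects, and the ids are made up):

from danswer.configs.constants import DocumentSource

existing_jobs = {
    (1, DocumentSource.CONFLUENCE): "<group sync Future>",  # group permission sync
    (2, 42): "<ACL sync Future>",  # ACL sync for cc-pair id 42
}

# Group syncs in flight are identified by DocumentSource keys...
sources_w_runs = [k for _, k in existing_jobs if isinstance(k, DocumentSource)]
# ...while ACL syncs in flight are identified by int cc-pair ids
cc_pairs_w_runs = [k for _, k in existing_jobs if isinstance(k, int)]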


def permission_loop(delay: int = 60, num_workers: int = NUM_PERMISSION_WORKERS) -> None:
    client: Client | SimpleJobClient
    if DASK_JOB_CLIENT_ENABLED:
        cluster_primary = LocalCluster(
            n_workers=num_workers,
            threads_per_worker=1,
            # there are warnings about high memory usage + "Event loop unresponsive"
            # which are not relevant to us since our workers are expected to use a
            # lot of memory + involve CPU intensive tasks that will not relinquish
            # the event loop
            silence_logs=logging.ERROR,
        )
        client = Client(cluster_primary)
        if LOG_LEVEL.lower() == "debug":
            client.register_worker_plugin(ResourceLogger())
    else:
        client = SimpleJobClient(n_workers=num_workers)

    existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob] = {}
    engine = get_sqlalchemy_engine()

    with Session(engine) as db_session:
        # Any jobs still in progress on restart must have died
        mark_all_inprogress_permission_sync_failed(db_session)

    while True:
        start = time.time()
        start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
        logger.info(f"Running Permission Sync, current UTC time: {start_time_utc}")

        if existing_jobs:
            logger.debug(
                "Found existing permission sync jobs: "
                f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
            )

        try:
            # TODO turn this on when it works
            """
            existing_jobs = cleanup_perm_sync_jobs(existing_jobs=existing_jobs)
            existing_jobs = create_group_sync_jobs(
                existing_jobs=existing_jobs, client=client
            )
            existing_jobs = create_connector_perm_sync_jobs(
                existing_jobs=existing_jobs, client=client
            )
            """
        except Exception as e:
            logger.exception(f"Failed to run update due to {e}")

        sleep_time = delay - (time.time() - start)
        if sleep_time > 0:
            time.sleep(sleep_time)


def update__main() -> None:
    logger.info("Starting Permission Syncing Loop")
    permission_loop()


if __name__ == "__main__":
    update__main()
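
For local testing, the loop can be launched the same way the supervisord entry above runs it: PYTHONPATH=. python ee/danswer/background/permission_sync.py. The dev script hunk at the bottom of this diff spawns the same command via subprocess.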

ee/danswer/configs/app_configs.py

@@ -15,3 +15,9 @@ _API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
API_KEY_HASH_ROUNDS = (
    int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
)


#####
# Auto Permission Sync
#####
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
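
The worker count can be raised per deployment, e.g. by exporting NUM_PERMISSION_WORKERS=4 before startup. Note that the "or 2" fallback also covers an empty-string value, not just an unset variable, since int("") would otherwise raise.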

ee/danswer/connectors/confluence/perm_sync.py

@@ -0,0 +1,12 @@
from danswer.utils.logger import setup_logger

logger = setup_logger()


def confluence_update_db_group() -> None:
    logger.debug("Not yet implemented group sync for confluence, no-op")


def confluence_update_index_acl(cc_pair_id: int) -> None:
    logger.debug("Not yet implemented ACL sync for confluence, no-op")

ee/danswer/connectors/factory.py

@@ -0,0 +1,8 @@
from danswer.configs.constants import DocumentSource
from ee.danswer.connectors.confluence.perm_sync import confluence_update_db_group
from ee.danswer.connectors.confluence.perm_sync import confluence_update_index_acl

CONNECTOR_PERMISSION_FUNC_MAP = {
    DocumentSource.CONFLUENCE: (confluence_update_db_group, confluence_update_index_acl)
}
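
A quick sketch of how the background loop consumes this map, mirroring the lookups in permission_sync.py above (the cc-pair id 42 is a made-up example):

from danswer.configs.constants import DocumentSource
from ee.danswer.connectors.factory import CONNECTOR_PERMISSION_FUNC_MAP

# Tuple order is (relational DB group sync fn, document index ACL sync fn)
db_group_fnc, index_sync_fnc = CONNECTOR_PERMISSION_FUNC_MAP[DocumentSource.CONFLUENCE]

db_group_fnc()      # group sync takes no arguments
index_sync_fnc(42)  # ACL sync takes a cc_pair_id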

ee/danswer/db/connector.py

@@ -0,0 +1,16 @@
from sqlalchemy import distinct
from sqlalchemy.orm import Session

from danswer.configs.constants import DocumentSource
from danswer.db.models import Connector
from danswer.utils.logger import setup_logger

logger = setup_logger()


def fetch_sources_with_connectors(db_session: Session) -> list[DocumentSource]:
    sources = db_session.query(distinct(Connector.source)).all()  # type: ignore
    document_sources = [source[0] for source in sources]

    return document_sources

ee/danswer/db/connector_credential_pair.py

@@ -0,0 +1,22 @@
from sqlalchemy.orm import Session

from danswer.configs.constants import DocumentSource
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.utils.logger import setup_logger

logger = setup_logger()


def get_cc_pairs_by_source(
    source_type: DocumentSource,
    db_session: Session,
) -> list[ConnectorCredentialPair]:
    cc_pairs = (
        db_session.query(ConnectorCredentialPair)
        .join(ConnectorCredentialPair.connector)
        .filter(Connector.source == source_type)
        .all()
    )

    return cc_pairs
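
For reference, this helper and fetch_sources_with_connectors from the previous file compose like this in the sync loop (a minimal sketch reusing the same session pattern as permission_sync.py):

from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from ee.danswer.db.connector import fetch_sources_with_connectors
from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source

with Session(get_sqlalchemy_engine()) as db_session:
    # Every source that has at least one connector...
    for source_type in fetch_sources_with_connectors(db_session):
        # ...yields one candidate ACL sync per connector/credential pair
        for cc_pair in get_cc_pairs_by_source(source_type, db_session):
            print(source_type, cc_pair.id)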

ee/danswer/db/permission_sync.py

@@ -0,0 +1,72 @@
from datetime import timedelta

from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session

from danswer.configs.constants import DocumentSource
from danswer.db.models import PermissionSyncRun
from danswer.db.models import PermissionSyncStatus
from danswer.utils.logger import setup_logger

logger = setup_logger()


def mark_all_inprogress_permission_sync_failed(
    db_session: Session,
) -> None:
    stmt = (
        update(PermissionSyncRun)
        .where(PermissionSyncRun.status == PermissionSyncStatus.IN_PROGRESS)
        .values(status=PermissionSyncStatus.FAILED)
    )

    db_session.execute(stmt)
    db_session.commit()


def get_perm_sync_attempt(attempt_id: int, db_session: Session) -> PermissionSyncRun:
    stmt = select(PermissionSyncRun).where(PermissionSyncRun.id == attempt_id)
    try:
        return db_session.scalars(stmt).one()
    except NoResultFound:
        raise ValueError(f"No PermissionSyncRun found with id {attempt_id}")


def expire_perm_sync_timed_out(
    timeout_hours: int,
    db_session: Session,
) -> None:
    cutoff_time = func.now() - timedelta(hours=timeout_hours)

    update_stmt = (
        update(PermissionSyncRun)
        .where(
            PermissionSyncRun.status == PermissionSyncStatus.IN_PROGRESS,
            PermissionSyncRun.updated_at < cutoff_time,
        )
        .values(status=PermissionSyncStatus.FAILED, error_msg="timed out")
    )

    db_session.execute(update_stmt)
    db_session.commit()


def create_perm_sync(
    source_type: DocumentSource,
    group_update: bool,
    cc_pair_id: int | None,
    db_session: Session,
) -> PermissionSyncRun:
    new_run = PermissionSyncRun(
        source_type=source_type,
        status=PermissionSyncStatus.IN_PROGRESS,
        group_update=group_update,
        cc_pair_id=cc_pair_id,
    )

    db_session.add(new_run)
    db_session.commit()

    return new_run
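
Taken together, these helpers define the lifecycle of a PermissionSyncRun. A minimal sketch of one pass, assuming DocumentSource.CONFLUENCE since that is the only source wired up in this commit (timeout_hours=1 is an arbitrary example value):

from sqlalchemy.orm import Session

from danswer.configs.constants import DocumentSource
from danswer.db.engine import get_sqlalchemy_engine
from ee.danswer.db.permission_sync import create_perm_sync
from ee.danswer.db.permission_sync import expire_perm_sync_timed_out
from ee.danswer.db.permission_sync import get_perm_sync_attempt

with Session(get_sqlalchemy_engine()) as db_session:
    # 1. Record an IN_PROGRESS run before submitting the actual job
    run = create_perm_sync(
        source_type=DocumentSource.CONFLUENCE,
        group_update=True,
        cc_pair_id=None,
        db_session=db_session,
    )
    # 2. Later passes look the attempt up by id to check its status
    attempt = get_perm_sync_attempt(attempt_id=run.id, db_session=db_session)
    # 3. Runs stuck IN_PROGRESS past the cutoff get marked FAILED ("timed out")
    expire_perm_sync_timed_out(timeout_hours=1, db_session=db_session)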

backend/scripts/dev_run_background_jobs.py

@@ -71,6 +71,26 @@ def run_jobs(exclude_indexing: bool) -> None:
    indexing_thread.start()
    indexing_thread.join()

    try:
        update_env = os.environ.copy()
        update_env["PYTHONPATH"] = "."
        cmd_perm_sync = ["python", "ee/danswer/background/permission_sync.py"]

        perm_sync_process = subprocess.Popen(
            cmd_perm_sync,
            env=update_env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )

        perm_sync_thread = threading.Thread(
            target=monitor_process, args=("PERM_SYNC", perm_sync_process)
        )
        perm_sync_thread.start()
        perm_sync_thread.join()
    except Exception:
        # Permission sync lives under ee/ (Enterprise Edition); ignore failures
        # so the open-source dev flow still works without it
        pass

    worker_thread.join()
    beat_thread.join()