mirror of https://github.com/danswer-ai/danswer.git (synced 2025-06-24 15:00:57 +02:00)

Permission Sync Framework (#44)

parent 1984f2c1ca
commit 0c827d1e6c

@@ -11,6 +11,14 @@ stdout_logfile_maxbytes=52428800
 redirect_stderr=true
 autorestart=true

+# Automatic sync-ing of access to documents from sources
+[program:permission_syncing]
+command=python ee/danswer/background/permission_sync.py
+stdout_logfile=/var/log/permission_sync.log
+stdout_logfile_maxbytes=52428800
+redirect_stderr=true
+autorestart=true
+
 # Background jobs that must be run async due to long time to completion
 [program:celery_worker]
 command=celery -A ee.danswer.background.celery worker --loglevel=INFO --logfile=/var/log/celery_worker.log
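
A quick arithmetic check on the log-rotation cap used throughout this config (52428800 bytes is exactly 50 MiB):

# stdout_logfile_maxbytes=52428800 in the config above is a 50 MiB rotation cap
assert 52428800 == 50 * 1024 * 1024

The new program entry invokes the `if __name__ == "__main__"` entry point of the file added below.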

backend/ee/danswer/background/permission_sync.py (new file, 221 lines)

import logging
import time
from datetime import datetime

import dask
from dask.distributed import Client
from dask.distributed import Future
from distributed import LocalCluster
from sqlalchemy.orm import Session

from danswer.background.indexing.dask_utils import ResourceLogger
from danswer.background.indexing.job_client import SimpleJob
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.constants import DocumentSource
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import PermissionSyncStatus
from danswer.utils.logger import setup_logger
from ee.danswer.configs.app_configs import NUM_PERMISSION_WORKERS
from ee.danswer.connectors.factory import CONNECTOR_PERMISSION_FUNC_MAP
from ee.danswer.db.connector import fetch_sources_with_connectors
from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source
from ee.danswer.db.permission_sync import create_perm_sync
from ee.danswer.db.permission_sync import expire_perm_sync_timed_out
from ee.danswer.db.permission_sync import get_perm_sync_attempt
from ee.danswer.db.permission_sync import mark_all_inprogress_permission_sync_failed
from shared_configs.configs import LOG_LEVEL

logger = setup_logger()

# If a sync job dies, it's most likely due to resource constraints;
# restarting just delays the eventual failure, so it is not useful to the user
dask.config.set({"distributed.scheduler.allowed-failures": 0})


def cleanup_perm_sync_jobs(
    existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob],
    # Just reusing the same timeout, fine for now
    timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]:
    existing_jobs_copy = existing_jobs.copy()

    with Session(get_sqlalchemy_engine()) as db_session:
        # clean up completed jobs
        for (attempt_id, details), job in existing_jobs.items():
            perm_sync_attempt = get_perm_sync_attempt(
                attempt_id=attempt_id, db_session=db_session
            )

            # do nothing for ongoing jobs that haven't been stopped
            if (
                not job.done()
                and perm_sync_attempt.status == PermissionSyncStatus.IN_PROGRESS
            ):
                continue

            if job.status == "error":
                logger.error(job.exception())

            job.release()
            del existing_jobs_copy[(attempt_id, details)]

        # clean up in-progress jobs that were never completed
        expire_perm_sync_timed_out(
            timeout_hours=timeout_hours,
            db_session=db_session,
        )

    return existing_jobs_copy


def create_group_sync_jobs(
    existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob],
    client: Client | SimpleJobClient,
) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]:
    """Creates a new relational DB group permission sync job for each source that:
    - has permission sync enabled
    - has at least 1 connector (enabled or paused)
    - has no sync already running
    """
    existing_jobs_copy = existing_jobs.copy()
    sources_w_runs = [
        key[1]
        for key in existing_jobs_copy.keys()
        if isinstance(key[1], DocumentSource)
    ]
    with Session(get_sqlalchemy_engine()) as db_session:
        sources_w_connector = fetch_sources_with_connectors(db_session)
        for source_type in sources_w_connector:
            if source_type not in CONNECTOR_PERMISSION_FUNC_MAP:
                continue
            if source_type in sources_w_runs:
                continue

            db_group_fnc, _ = CONNECTOR_PERMISSION_FUNC_MAP[source_type]
            perm_sync = create_perm_sync(
                source_type=source_type,
                group_update=True,
                cc_pair_id=None,
                db_session=db_session,
            )

            run = client.submit(db_group_fnc, pure=False)

            logger.info(
                f"Kicked off group permission sync for source type {source_type}"
            )

            if run:
                existing_jobs_copy[(perm_sync.id, source_type)] = run

    return existing_jobs_copy


def create_connector_perm_sync_jobs(
    existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob],
    client: Client | SimpleJobClient,
) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]:
    """Creates a new Document Index ACL update job for each cc-pair where:
    - source type has permission sync enabled
    - has no sync already running
    """
    existing_jobs_copy = existing_jobs.copy()
    cc_pairs_w_runs = [
        key[1]
        for key in existing_jobs_copy.keys()
        # ACL runs are keyed by the cc-pair's int id (group runs use DocumentSource)
        if isinstance(key[1], int)
    ]
    with Session(get_sqlalchemy_engine()) as db_session:
        sources_w_connector = fetch_sources_with_connectors(db_session)
        for source_type in sources_w_connector:
            if source_type not in CONNECTOR_PERMISSION_FUNC_MAP:
                continue

            _, index_sync_fnc = CONNECTOR_PERMISSION_FUNC_MAP[source_type]

            cc_pairs = get_cc_pairs_by_source(source_type, db_session)

            for cc_pair in cc_pairs:
                if cc_pair.id in cc_pairs_w_runs:
                    continue

                perm_sync = create_perm_sync(
                    source_type=source_type,
                    group_update=False,
                    cc_pair_id=cc_pair.id,
                    db_session=db_session,
                )

                run = client.submit(index_sync_fnc, cc_pair.id, pure=False)

                logger.info(f"Kicked off ACL sync for cc-pair {cc_pair.id}")

                if run:
                    existing_jobs_copy[(perm_sync.id, cc_pair.id)] = run

    return existing_jobs_copy


def permission_loop(delay: int = 60, num_workers: int = NUM_PERMISSION_WORKERS) -> None:
    client: Client | SimpleJobClient
    if DASK_JOB_CLIENT_ENABLED:
        cluster_primary = LocalCluster(
            n_workers=num_workers,
            threads_per_worker=1,
            # there are warnings about high memory usage + "Event loop unresponsive"
            # which are not relevant to us since our workers are expected to use a
            # lot of memory + involve CPU intensive tasks that will not relinquish
            # the event loop
            silence_logs=logging.ERROR,
        )
        client = Client(cluster_primary)
        if LOG_LEVEL.lower() == "debug":
            client.register_worker_plugin(ResourceLogger())
    else:
        client = SimpleJobClient(n_workers=num_workers)

    existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob] = {}
    engine = get_sqlalchemy_engine()

    with Session(engine) as db_session:
        # Any jobs still in progress on restart must have died
        mark_all_inprogress_permission_sync_failed(db_session)

    while True:
        start = time.time()
        start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
        logger.info(f"Running Permission Sync, current UTC time: {start_time_utc}")

        if existing_jobs:
            logger.debug(
                "Found existing permission sync jobs: "
                f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
            )

        try:
            # TODO turn this on when it works
            """
            existing_jobs = cleanup_perm_sync_jobs(existing_jobs=existing_jobs)
            existing_jobs = create_group_sync_jobs(
                existing_jobs=existing_jobs, client=client
            )
            existing_jobs = create_connector_perm_sync_jobs(
                existing_jobs=existing_jobs, client=client
            )
            """
        except Exception as e:
            logger.exception(f"Failed to run update due to {e}")

        sleep_time = delay - (time.time() - start)
        if sleep_time > 0:
            time.sleep(sleep_time)


def update__main() -> None:
    logger.info("Starting Permission Syncing Loop")
    permission_loop()


if __name__ == "__main__":
    update__main()
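
The `existing_jobs` dict above does double duty: group syncs are keyed by `(attempt_id, DocumentSource)` while per-cc-pair ACL syncs are keyed by `(attempt_id, cc_pair_id)`, and the isinstance filters split the two apart. A minimal self-contained sketch of that dedup logic (the enum value, ids, and string placeholders are made up for illustration; danswer's real DocumentSource is imported, not redefined):

from enum import Enum


class DocumentSource(str, Enum):  # stand-in for danswer.configs.constants.DocumentSource
    CONFLUENCE = "confluence"


# (attempt_id, DocumentSource) -> group sync; (attempt_id, int cc_pair_id) -> ACL sync
existing_jobs = {
    (1, DocumentSource.CONFLUENCE): "group-sync-future",
    (2, 42): "acl-sync-future",
}

sources_w_runs = [k[1] for k in existing_jobs if isinstance(k[1], DocumentSource)]
cc_pairs_w_runs = [k[1] for k in existing_jobs if isinstance(k[1], int)]

assert sources_w_runs == [DocumentSource.CONFLUENCE]
assert cc_pairs_w_runs == [42]

A str-valued enum member is not an int, so each comprehension picks out exactly one kind of run.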

@@ -15,3 +15,9 @@ _API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
 API_KEY_HASH_ROUNDS = (
     int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
 )
+
+
+#####
+# Auto Permission Sync
+#####
+NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
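
The `or 2` idiom (rather than passing a default to `os.environ.get`) also covers the case where the variable is set but empty; a small demonstration:

import os

os.environ["NUM_PERMISSION_WORKERS"] = ""  # set but empty

# get() with a default argument would return "" here, and int("") raises ValueError;
# "" is falsy, so `or 2` falls back to the default instead
value = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
assert value == 2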

backend/ee/danswer/connectors/__init__.py (new file, 0 lines)

backend/ee/danswer/connectors/confluence/perm_sync.py (new file, 12 lines)

from danswer.utils.logger import setup_logger


logger = setup_logger()


def confluence_update_db_group() -> None:
    logger.debug("Not yet implemented group sync for confluence, no-op")


def confluence_update_index_acl(cc_pair_id: int) -> None:
    logger.debug("Not yet implemented ACL sync for confluence, no-op")

backend/ee/danswer/connectors/factory.py (new file, 8 lines)

from danswer.configs.constants import DocumentSource
from ee.danswer.connectors.confluence.perm_sync import confluence_update_db_group
from ee.danswer.connectors.confluence.perm_sync import confluence_update_index_acl


CONNECTOR_PERMISSION_FUNC_MAP = {
    DocumentSource.CONFLUENCE: (confluence_update_db_group, confluence_update_index_acl)
}
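
Enabling permission sync for another source would mean registering a `(group_sync_fn, acl_sync_fn)` pair in this map. A hypothetical sketch, assuming a danswer install — the slack functions below are illustrative only and are not part of this commit:

from danswer.configs.constants import DocumentSource


def slack_update_db_group() -> None:  # hypothetical group-membership sync
    ...


def slack_update_index_acl(cc_pair_id: int) -> None:  # hypothetical ACL sync
    ...


CONNECTOR_PERMISSION_FUNC_MAP = {
    DocumentSource.CONFLUENCE: ...,  # the confluence pair registered above
    DocumentSource.SLACK: (slack_update_db_group, slack_update_index_acl),
}

The background loop then picks the new source up automatically, since it only schedules work for sources present in this map.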

backend/ee/danswer/db/connector.py (new file, 16 lines)

from sqlalchemy import distinct
from sqlalchemy.orm import Session

from danswer.configs.constants import DocumentSource
from danswer.db.models import Connector
from danswer.utils.logger import setup_logger

logger = setup_logger()


def fetch_sources_with_connectors(db_session: Session) -> list[DocumentSource]:
    sources = db_session.query(distinct(Connector.source)).all()  # type: ignore

    # each result row is a 1-tuple, so unpack the DocumentSource out of it
    document_sources = [source[0] for source in sources]

    return document_sources

backend/ee/danswer/db/connector_credential_pair.py (new file, 22 lines)

from sqlalchemy.orm import Session

from danswer.configs.constants import DocumentSource
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.utils.logger import setup_logger

logger = setup_logger()


def get_cc_pairs_by_source(
    source_type: DocumentSource,
    db_session: Session,
) -> list[ConnectorCredentialPair]:
    cc_pairs = (
        db_session.query(ConnectorCredentialPair)
        .join(ConnectorCredentialPair.connector)
        .filter(Connector.source == source_type)
        .all()
    )

    return cc_pairs

backend/ee/danswer/db/permission_sync.py (new file, 72 lines)

from datetime import timedelta

from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session

from danswer.configs.constants import DocumentSource
from danswer.db.models import PermissionSyncRun
from danswer.db.models import PermissionSyncStatus
from danswer.utils.logger import setup_logger

logger = setup_logger()


def mark_all_inprogress_permission_sync_failed(
    db_session: Session,
) -> None:
    stmt = (
        update(PermissionSyncRun)
        .where(PermissionSyncRun.status == PermissionSyncStatus.IN_PROGRESS)
        .values(status=PermissionSyncStatus.FAILED)
    )

    db_session.execute(stmt)
    db_session.commit()


def get_perm_sync_attempt(attempt_id: int, db_session: Session) -> PermissionSyncRun:
    stmt = select(PermissionSyncRun).where(PermissionSyncRun.id == attempt_id)
    try:
        return db_session.scalars(stmt).one()
    except NoResultFound:
        raise ValueError(f"No PermissionSyncRun found with id {attempt_id}")


def expire_perm_sync_timed_out(
    timeout_hours: int,
    db_session: Session,
) -> None:
    cutoff_time = func.now() - timedelta(hours=timeout_hours)

    update_stmt = (
        update(PermissionSyncRun)
        .where(
            PermissionSyncRun.status == PermissionSyncStatus.IN_PROGRESS,
            PermissionSyncRun.updated_at < cutoff_time,
        )
        .values(status=PermissionSyncStatus.FAILED, error_msg="timed out")
    )

    db_session.execute(update_stmt)
    db_session.commit()


def create_perm_sync(
    source_type: DocumentSource,
    group_update: bool,
    cc_pair_id: int | None,
    db_session: Session,
) -> PermissionSyncRun:
    new_run = PermissionSyncRun(
        source_type=source_type,
        status=PermissionSyncStatus.IN_PROGRESS,
        group_update=group_update,
        cc_pair_id=cc_pair_id,
    )

    db_session.add(new_run)
    db_session.commit()

    return new_run
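
Together these helpers give each PermissionSyncRun a simple lifecycle: created as IN_PROGRESS, then marked FAILED on daemon restart or after a timeout. Below is a self-contained sketch of the timeout-expiry pattern using an in-memory SQLite database and a stand-in model — the real code uses the DB-side `func.now()` rather than Python-side timestamps, and `PermissionSyncRun` lives in danswer's models:

from datetime import datetime, timedelta

from sqlalchemy import Column, DateTime, Integer, String, create_engine, update
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Run(Base):  # stand-in for PermissionSyncRun
    __tablename__ = "run"
    id = Column(Integer, primary_key=True)
    status = Column(String, default="IN_PROGRESS")
    updated_at = Column(DateTime, default=datetime.utcnow)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as s:
    s.add(Run(updated_at=datetime.utcnow() - timedelta(hours=5)))  # stale run
    s.add(Run())  # fresh run
    s.commit()

    # expire anything still IN_PROGRESS that hasn't been touched within the cutoff
    cutoff = datetime.utcnow() - timedelta(hours=1)
    s.execute(
        update(Run)
        .where(Run.status == "IN_PROGRESS", Run.updated_at < cutoff)
        .values(status="FAILED")
    )
    s.commit()

    assert [r.status for r in s.query(Run).order_by(Run.id)] == ["FAILED", "IN_PROGRESS"]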

@@ -71,6 +71,26 @@ def run_jobs(exclude_indexing: bool) -> None:

     indexing_thread.start()
     indexing_thread.join()

+    try:
+        update_env = os.environ.copy()
+        update_env["PYTHONPATH"] = "."
+        cmd_perm_sync = ["python", "ee/danswer/background/permission_sync.py"]
+
+        perm_sync_process = subprocess.Popen(
+            cmd_perm_sync,
+            env=update_env,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            text=True,
+        )
+
+        perm_sync_thread = threading.Thread(
+            target=monitor_process, args=("PERM_SYNC", perm_sync_process)
+        )
+        perm_sync_thread.start()
+        perm_sync_thread.join()
+    except Exception:
+        pass

     worker_thread.join()
     beat_thread.join()
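
`monitor_process` is defined earlier in this dev script; from the call site it takes a label and a Popen handle. A hedged sketch of what such a helper typically does — an assumption for illustration, not the repo's actual implementation:

import subprocess


def monitor_process(name: str, process: subprocess.Popen) -> None:
    # Assumed behavior: prefix and forward each line of the child's combined output
    # (the Popen above merges stderr into stdout and enables text mode)
    assert process.stdout is not None
    for line in process.stdout:
        print(f"{name}: {line}", end="")
    print(f"{name}: process terminated")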