Mirror of https://github.com/danswer-ai/danswer.git, synced 2025-03-26 17:51:54 +01:00
Permission Sync Framework (#44)
parent 1984f2c1ca
commit 0c827d1e6c
@@ -11,6 +11,14 @@ stdout_logfile_maxbytes=52428800
redirect_stderr=true
autorestart=true

# Automatic sync-ing of access to documents from sources
[program:permission_syncing]
command=python ee/danswer/background/permission_sync.py
stdout_logfile=/var/log/permission_sync.log
stdout_logfile_maxbytes=52428800
redirect_stderr=true
autorestart=true

# Background jobs that must be run async due to long time to completion
[program:celery_worker]
command=celery -A ee.danswer.background.celery worker --loglevel=INFO --logfile=/var/log/celery_worker.log
221 backend/ee/danswer/background/permission_sync.py Normal file
@@ -0,0 +1,221 @@
import logging
import time
from datetime import datetime

import dask
from dask.distributed import Client
from dask.distributed import Future
from distributed import LocalCluster
from sqlalchemy.orm import Session

from danswer.background.indexing.dask_utils import ResourceLogger
from danswer.background.indexing.job_client import SimpleJob
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.constants import DocumentSource
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import PermissionSyncStatus
from danswer.utils.logger import setup_logger
from ee.danswer.configs.app_configs import NUM_PERMISSION_WORKERS
from ee.danswer.connectors.factory import CONNECTOR_PERMISSION_FUNC_MAP
from ee.danswer.db.connector import fetch_sources_with_connectors
from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source
from ee.danswer.db.permission_sync import create_perm_sync
from ee.danswer.db.permission_sync import expire_perm_sync_timed_out
from ee.danswer.db.permission_sync import get_perm_sync_attempt
from ee.danswer.db.permission_sync import mark_all_inprogress_permission_sync_failed
from shared_configs.configs import LOG_LEVEL

logger = setup_logger()

# If a permission sync job dies, it's most likely due to resource constraints;
# restarting just delays the eventual failure and is not useful to the user
dask.config.set({"distributed.scheduler.allowed-failures": 0})


def cleanup_perm_sync_jobs(
    existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob],
    # Just reusing the same timeout, fine for now
    timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]:
    existing_jobs_copy = existing_jobs.copy()

    with Session(get_sqlalchemy_engine()) as db_session:
        # clean up completed jobs
        for (attempt_id, details), job in existing_jobs.items():
            perm_sync_attempt = get_perm_sync_attempt(
                attempt_id=attempt_id, db_session=db_session
            )

            # do nothing for ongoing jobs that haven't been stopped
            if (
                not job.done()
                and perm_sync_attempt.status == PermissionSyncStatus.IN_PROGRESS
            ):
                continue

            if job.status == "error":
                logger.error(job.exception())

            job.release()
            del existing_jobs_copy[(attempt_id, details)]

        # clean up in-progress jobs that were never completed
        expire_perm_sync_timed_out(
            timeout_hours=timeout_hours,
            db_session=db_session,
        )

    return existing_jobs_copy


def create_group_sync_jobs(
    existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob],
    client: Client | SimpleJobClient,
) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]:
    """Creates a new relational-DB group permission sync job for each source that:
    - has permission sync enabled
    - has at least 1 connector (enabled or paused)
    - has no sync already running
    """
    existing_jobs_copy = existing_jobs.copy()
    sources_w_runs = [
        key[1]
        for key in existing_jobs_copy.keys()
        if isinstance(key[1], DocumentSource)
    ]
    with Session(get_sqlalchemy_engine()) as db_session:
        sources_w_connector = fetch_sources_with_connectors(db_session)
        for source_type in sources_w_connector:
            if source_type not in CONNECTOR_PERMISSION_FUNC_MAP:
                continue
            if source_type in sources_w_runs:
                continue

            db_group_fnc, _ = CONNECTOR_PERMISSION_FUNC_MAP[source_type]
            perm_sync = create_perm_sync(
                source_type=source_type,
                group_update=True,
                cc_pair_id=None,
                db_session=db_session,
            )

            run = client.submit(db_group_fnc, pure=False)

            logger.info(
                f"Kicked off group permission sync for source type {source_type}"
            )

            if run:
                existing_jobs_copy[(perm_sync.id, source_type)] = run

    return existing_jobs_copy


def create_connector_perm_sync_jobs(
    existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob],
    client: Client | SimpleJobClient,
) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]:
    """Creates a Document Index ACL sync job for each cc-pair where:
    - the source type has permission sync enabled
    - no sync is already running
    """
    existing_jobs_copy = existing_jobs.copy()
    cc_pairs_w_runs = [
        key[1]
        for key in existing_jobs_copy.keys()
        # ACL sync jobs are keyed by cc-pair id (an int), not by DocumentSource
        if isinstance(key[1], int)
    ]
    with Session(get_sqlalchemy_engine()) as db_session:
        sources_w_connector = fetch_sources_with_connectors(db_session)
        for source_type in sources_w_connector:
            if source_type not in CONNECTOR_PERMISSION_FUNC_MAP:
                continue

            _, index_sync_fnc = CONNECTOR_PERMISSION_FUNC_MAP[source_type]

            cc_pairs = get_cc_pairs_by_source(source_type, db_session)

            for cc_pair in cc_pairs:
                if cc_pair.id in cc_pairs_w_runs:
                    continue

                perm_sync = create_perm_sync(
                    source_type=source_type,
                    group_update=False,
                    cc_pair_id=cc_pair.id,
                    db_session=db_session,
                )

                run = client.submit(index_sync_fnc, cc_pair.id, pure=False)

                logger.info(f"Kicked off ACL sync for cc-pair {cc_pair.id}")

                if run:
                    existing_jobs_copy[(perm_sync.id, cc_pair.id)] = run

    return existing_jobs_copy


def permission_loop(delay: int = 60, num_workers: int = NUM_PERMISSION_WORKERS) -> None:
    client: Client | SimpleJobClient
    if DASK_JOB_CLIENT_ENABLED:
        cluster_primary = LocalCluster(
            n_workers=num_workers,
            threads_per_worker=1,
            # there are warnings about high memory usage + "Event loop unresponsive"
            # which are not relevant to us since our workers are expected to use a
            # lot of memory + involve CPU intensive tasks that will not relinquish
            # the event loop
            silence_logs=logging.ERROR,
        )
        client = Client(cluster_primary)
        if LOG_LEVEL.lower() == "debug":
            client.register_worker_plugin(ResourceLogger())
    else:
        client = SimpleJobClient(n_workers=num_workers)

    existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob] = {}
    engine = get_sqlalchemy_engine()

    with Session(engine) as db_session:
        # Any jobs still in progress on restart must have died
        mark_all_inprogress_permission_sync_failed(db_session)

    while True:
        start = time.time()
        start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
        logger.info(f"Running Permission Sync, current UTC time: {start_time_utc}")

        if existing_jobs:
            logger.debug(
                "Found existing permission sync jobs: "
                f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
            )

        try:
            # TODO turn this on when it works
            """
            existing_jobs = cleanup_perm_sync_jobs(existing_jobs=existing_jobs)
            existing_jobs = create_group_sync_jobs(
                existing_jobs=existing_jobs, client=client
            )
            existing_jobs = create_connector_perm_sync_jobs(
                existing_jobs=existing_jobs, client=client
            )
            """
        except Exception as e:
            logger.exception(f"Failed to run update due to {e}")
        sleep_time = delay - (time.time() - start)
        if sleep_time > 0:
            time.sleep(sleep_time)


def update__main() -> None:
    logger.info("Starting Permission Syncing Loop")
    permission_loop()


if __name__ == "__main__":
    update__main()
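The loop above tells its two job kinds apart purely by the second element of each job key: a DocumentSource for group syncs, a cc-pair id (an int) for ACL syncs. A minimal, self-contained sketch of that discrimination; the DocumentSource enum below is a stand-in for danswer.configs.constants.DocumentSource, assumed here to be a str-based Enum:

from enum import Enum


class DocumentSource(str, Enum):
    # stand-in for danswer.configs.constants.DocumentSource
    CONFLUENCE = "confluence"


# keys follow the tuple[int, int | DocumentSource] annotation used above
existing_jobs = {
    (1, DocumentSource.CONFLUENCE): "group sync job",  # group permission sync
    (2, 42): "ACL sync job",  # ACL sync for cc_pair_id == 42
}

# group syncs are picked out by enum type, ACL syncs by plain int
# (a str-based enum member is not an instance of int)
sources_w_runs = [key for _, key in existing_jobs if isinstance(key, DocumentSource)]
cc_pairs_w_runs = [key for _, key in existing_jobs if isinstance(key, int)]

assert sources_w_runs == [DocumentSource.CONFLUENCE]
assert cc_pairs_w_runs == [42]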
@@ -15,3 +15,9 @@ _API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
API_KEY_HASH_ROUNDS = (
    int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
)


#####
# Auto Permission Sync
#####
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
0 backend/ee/danswer/connectors/__init__.py Normal file
12 backend/ee/danswer/connectors/confluence/perm_sync.py Normal file
@@ -0,0 +1,12 @@
from danswer.utils.logger import setup_logger


logger = setup_logger()


def confluence_update_db_group() -> None:
    logger.debug("Not yet implemented group sync for confluence, no-op")


def confluence_update_index_acl(cc_pair_id: int) -> None:
    logger.debug("Not yet implemented ACL sync for confluence, no-op")
8 backend/ee/danswer/connectors/factory.py Normal file
@@ -0,0 +1,8 @@
from danswer.configs.constants import DocumentSource
from ee.danswer.connectors.confluence.perm_sync import confluence_update_db_group
from ee.danswer.connectors.confluence.perm_sync import confluence_update_index_acl


CONNECTOR_PERMISSION_FUNC_MAP = {
    DocumentSource.CONFLUENCE: (confluence_update_db_group, confluence_update_index_acl)
}
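Each entry maps a source to a (group sync fn, ACL sync fn) pair, which is what permission_sync.py unpacks as db_group_fnc and index_sync_fnc. A hypothetical registration for a second source could look like the following; the google_drive functions are illustrative placeholders, not modules that exist in this commit:

from danswer.configs.constants import DocumentSource
from ee.danswer.connectors.factory import CONNECTOR_PERMISSION_FUNC_MAP


def google_drive_update_db_group() -> None:
    ...  # would mirror confluence_update_db_group


def google_drive_update_index_acl(cc_pair_id: int) -> None:
    ...  # would mirror confluence_update_index_acl


CONNECTOR_PERMISSION_FUNC_MAP[DocumentSource.GOOGLE_DRIVE] = (
    google_drive_update_db_group,
    google_drive_update_index_acl,
)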
16 backend/ee/danswer/db/connector.py Normal file
@@ -0,0 +1,16 @@
from sqlalchemy import distinct
from sqlalchemy.orm import Session

from danswer.configs.constants import DocumentSource
from danswer.db.models import Connector
from danswer.utils.logger import setup_logger

logger = setup_logger()


def fetch_sources_with_connectors(db_session: Session) -> list[DocumentSource]:
    sources = db_session.query(distinct(Connector.source)).all()  # type: ignore

    document_sources = [source[0] for source in sources]

    return document_sources
22 backend/ee/danswer/db/connector_credential_pair.py Normal file
@@ -0,0 +1,22 @@
from sqlalchemy.orm import Session

from danswer.configs.constants import DocumentSource
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.utils.logger import setup_logger

logger = setup_logger()


def get_cc_pairs_by_source(
    source_type: DocumentSource,
    db_session: Session,
) -> list[ConnectorCredentialPair]:
    cc_pairs = (
        db_session.query(ConnectorCredentialPair)
        .join(ConnectorCredentialPair.connector)
        .filter(Connector.source == source_type)
        .all()
    )

    return cc_pairs
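get_cc_pairs_by_source is what feeds create_connector_perm_sync_jobs with every connector-credential pairing for a source. A quick usage sketch, assuming a configured database engine:

from sqlalchemy.orm import Session

from danswer.configs.constants import DocumentSource
from danswer.db.engine import get_sqlalchemy_engine
from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source

with Session(get_sqlalchemy_engine()) as db_session:
    confluence_pairs = get_cc_pairs_by_source(DocumentSource.CONFLUENCE, db_session)
    # one ACL sync job would be launched per pair id
    print([cc_pair.id for cc_pair in confluence_pairs])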
72 backend/ee/danswer/db/permission_sync.py Normal file
@@ -0,0 +1,72 @@
from datetime import timedelta

from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session

from danswer.configs.constants import DocumentSource
from danswer.db.models import PermissionSyncRun
from danswer.db.models import PermissionSyncStatus
from danswer.utils.logger import setup_logger

logger = setup_logger()


def mark_all_inprogress_permission_sync_failed(
    db_session: Session,
) -> None:
    stmt = (
        update(PermissionSyncRun)
        .where(PermissionSyncRun.status == PermissionSyncStatus.IN_PROGRESS)
        .values(status=PermissionSyncStatus.FAILED)
    )
    db_session.execute(stmt)
    db_session.commit()


def get_perm_sync_attempt(attempt_id: int, db_session: Session) -> PermissionSyncRun:
    stmt = select(PermissionSyncRun).where(PermissionSyncRun.id == attempt_id)
    try:
        return db_session.scalars(stmt).one()
    except NoResultFound:
        raise ValueError(f"No PermissionSyncRun found with id {attempt_id}")


def expire_perm_sync_timed_out(
    timeout_hours: int,
    db_session: Session,
) -> None:
    cutoff_time = func.now() - timedelta(hours=timeout_hours)

    update_stmt = (
        update(PermissionSyncRun)
        .where(
            PermissionSyncRun.status == PermissionSyncStatus.IN_PROGRESS,
            PermissionSyncRun.updated_at < cutoff_time,
        )
        .values(status=PermissionSyncStatus.FAILED, error_msg="timed out")
    )

    db_session.execute(update_stmt)
    db_session.commit()


def create_perm_sync(
    source_type: DocumentSource,
    group_update: bool,
    cc_pair_id: int | None,
    db_session: Session,
) -> PermissionSyncRun:
    new_run = PermissionSyncRun(
        source_type=source_type,
        status=PermissionSyncStatus.IN_PROGRESS,
        group_update=group_update,
        cc_pair_id=cc_pair_id,
    )

    db_session.add(new_run)
    db_session.commit()

    return new_run
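These helpers cover the whole lifecycle of a PermissionSyncRun row: created as IN_PROGRESS, fetched by id during cleanup, and failed on timeout or on process restart. A rough usage sketch tying them together, assuming a configured database engine:

from sqlalchemy.orm import Session

from danswer.configs.constants import DocumentSource
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import PermissionSyncStatus
from ee.danswer.db.permission_sync import create_perm_sync
from ee.danswer.db.permission_sync import expire_perm_sync_timed_out
from ee.danswer.db.permission_sync import get_perm_sync_attempt

with Session(get_sqlalchemy_engine()) as db_session:
    # open a group-level sync run for confluence (no specific cc-pair)
    run = create_perm_sync(
        source_type=DocumentSource.CONFLUENCE,
        group_update=True,
        cc_pair_id=None,
        db_session=db_session,
    )

    # look the run back up by id, as cleanup_perm_sync_jobs does
    attempt = get_perm_sync_attempt(attempt_id=run.id, db_session=db_session)
    assert attempt.status == PermissionSyncStatus.IN_PROGRESS

    # fail anything that has sat IN_PROGRESS for more than 3 hours
    expire_perm_sync_timed_out(timeout_hours=3, db_session=db_session)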
@@ -71,6 +71,26 @@ def run_jobs(exclude_indexing: bool) -> None:

    indexing_thread.start()
    indexing_thread.join()
    try:
        update_env = os.environ.copy()
        update_env["PYTHONPATH"] = "."
        cmd_perm_sync = ["python", "ee/danswer/background/permission_sync.py"]

        perm_sync_process = subprocess.Popen(
            cmd_perm_sync,
            env=update_env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )

        perm_sync_thread = threading.Thread(
            target=monitor_process, args=("PERM_SYNC", perm_sync_process)
        )
        perm_sync_thread.start()
        perm_sync_thread.join()
    except Exception:
        pass

    worker_thread.join()
    beat_thread.join()