Added confluence permission syncing (#2537)

* Added confluence permission syncing
* Separated out group and doc syncing
* Minor bugfix and mypy
* Added frontend and fixed bug
* Minor refactor
* Dealt with Confluence rate limits
* Mypy fixes
* Addressed Yuhong's feedback
* Primary key fix
@@ -11,12 +11,25 @@ from danswer.server.settings.store import load_settings
 from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import global_version
 from ee.danswer.background.celery_utils import should_perform_chat_ttl_check
-from ee.danswer.background.celery_utils import should_perform_external_permissions_check
+from ee.danswer.background.celery_utils import (
+    should_perform_external_doc_permissions_check,
+)
+from ee.danswer.background.celery_utils import (
+    should_perform_external_group_permissions_check,
+)
 from ee.danswer.background.task_name_builders import name_chat_ttl_task
-from ee.danswer.background.task_name_builders import name_sync_external_permissions_task
+from ee.danswer.background.task_name_builders import (
+    name_sync_external_doc_permissions_task,
+)
+from ee.danswer.background.task_name_builders import (
+    name_sync_external_group_permissions_task,
+)
 from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
 from ee.danswer.external_permissions.permission_sync import (
-    run_permission_sync_entrypoint,
+    run_external_doc_permission_sync,
 )
+from ee.danswer.external_permissions.permission_sync import (
+    run_external_group_permission_sync,
+)
 from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report

@@ -26,11 +39,18 @@ logger = setup_logger()
 global_version.set_ee()


-@build_celery_task_wrapper(name_sync_external_permissions_task)
+@build_celery_task_wrapper(name_sync_external_doc_permissions_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
-def sync_external_permissions_task(cc_pair_id: int) -> None:
+def sync_external_doc_permissions_task(cc_pair_id: int) -> None:
     with Session(get_sqlalchemy_engine()) as db_session:
-        run_permission_sync_entrypoint(db_session=db_session, cc_pair_id=cc_pair_id)
+        run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
+
+
+@build_celery_task_wrapper(name_sync_external_group_permissions_task)
+@celery_app.task(soft_time_limit=JOB_TIMEOUT)
+def sync_external_group_permissions_task(cc_pair_id: int) -> None:
+    with Session(get_sqlalchemy_engine()) as db_session:
+        run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)


 @build_celery_task_wrapper(name_chat_ttl_task)
@@ -44,18 +64,35 @@ def perform_ttl_management_task(retention_limit_days: int) -> None:
 # Periodic Tasks
 #####
 @celery_app.task(
-    name="check_sync_external_permissions_task",
+    name="check_sync_external_doc_permissions_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def check_sync_external_permissions_task() -> None:
+def check_sync_external_doc_permissions_task() -> None:
     """Runs periodically to sync external permissions"""
     with Session(get_sqlalchemy_engine()) as db_session:
         cc_pairs = get_all_auto_sync_cc_pairs(db_session)
         for cc_pair in cc_pairs:
-            if should_perform_external_permissions_check(
+            if should_perform_external_doc_permissions_check(
                 cc_pair=cc_pair, db_session=db_session
             ):
-                sync_external_permissions_task.apply_async(
+                sync_external_doc_permissions_task.apply_async(
                     kwargs=dict(cc_pair_id=cc_pair.id),
                 )
+
+
+@celery_app.task(
+    name="check_sync_external_group_permissions_task",
+    soft_time_limit=JOB_TIMEOUT,
+)
+def check_sync_external_group_permissions_task() -> None:
+    """Runs periodically to sync external group permissions"""
+    with Session(get_sqlalchemy_engine()) as db_session:
+        cc_pairs = get_all_auto_sync_cc_pairs(db_session)
+        for cc_pair in cc_pairs:
+            if should_perform_external_group_permissions_check(
+                cc_pair=cc_pair, db_session=db_session
+            ):
+                sync_external_group_permissions_task.apply_async(
+                    kwargs=dict(cc_pair_id=cc_pair.id),
+                )

@@ -94,9 +131,13 @@ def autogenerate_usage_report_task() -> None:
 # Celery Beat (Periodic Tasks) Settings
 #####
 celery_app.conf.beat_schedule = {
-    "sync-external-permissions": {
-        "task": "check_sync_external_permissions_task",
-        "schedule": timedelta(seconds=60),  # TODO: optimize this
+    "sync-external-doc-permissions": {
+        "task": "check_sync_external_doc_permissions_task",
+        "schedule": timedelta(seconds=5),  # TODO: optimize this
+    },
+    "sync-external-group-permissions": {
+        "task": "check_sync_external_group_permissions_task",
+        "schedule": timedelta(seconds=5),  # TODO: optimize this
     },
     "autogenerate_usage_report": {
        "task": "autogenerate_usage_report_task",
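The beat entries above fire the two check tasks every five seconds; each check task then enqueues real sync work only for cc_pairs that pass the dedupe checks in celery_utils. A minimal sketch of the same Celery beat pattern, with a hypothetical task name (not part of this commit):

from datetime import timedelta

from celery import Celery

celery_app = Celery(__name__)


@celery_app.task(name="check_example_task", soft_time_limit=60)
def check_example_task() -> None:
    # A real check task would enqueue work only when a dedupe check passes
    ...


celery_app.conf.beat_schedule = {
    "check-example": {
        "task": "check_example_task",
        "schedule": timedelta(seconds=5),
    },
}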
@@ -12,7 +12,12 @@ from danswer.db.tasks import check_task_is_live_and_not_timed_out
 from danswer.db.tasks import get_latest_task
 from danswer.utils.logger import setup_logger
 from ee.danswer.background.task_name_builders import name_chat_ttl_task
-from ee.danswer.background.task_name_builders import name_sync_external_permissions_task
+from ee.danswer.background.task_name_builders import (
+    name_sync_external_doc_permissions_task,
+)
+from ee.danswer.background.task_name_builders import (
+    name_sync_external_group_permissions_task,
+)
 from ee.danswer.db.user_group import delete_user_group
 from ee.danswer.db.user_group import fetch_user_group
 from ee.danswer.db.user_group import mark_user_group_as_synced
@@ -38,13 +43,32 @@ def should_perform_chat_ttl_check(
     return True


-def should_perform_external_permissions_check(
+def should_perform_external_doc_permissions_check(
     cc_pair: ConnectorCredentialPair, db_session: Session
 ) -> bool:
     if cc_pair.access_type != AccessType.SYNC:
         return False

-    task_name = name_sync_external_permissions_task(cc_pair_id=cc_pair.id)
+    task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair.id)

     latest_task = get_latest_task(task_name, db_session)
     if not latest_task:
         return True

     if check_task_is_live_and_not_timed_out(latest_task, db_session):
         logger.debug(f"{task_name} is already being performed. Skipping.")
         return False

     return True


+def should_perform_external_group_permissions_check(
+    cc_pair: ConnectorCredentialPair, db_session: Session
+) -> bool:
+    if cc_pair.access_type != AccessType.SYNC:
+        return False
+
+    task_name = name_sync_external_group_permissions_task(cc_pair_id=cc_pair.id)
+
+    latest_task = get_latest_task(task_name, db_session)
+    if not latest_task:
@@ -2,5 +2,9 @@ def name_chat_ttl_task(retention_limit_days: int) -> str:
     return f"chat_ttl_{retention_limit_days}_days"


-def name_sync_external_permissions_task(cc_pair_id: int) -> str:
-    return f"sync_external_permissions_task__{cc_pair_id}"
+def name_sync_external_doc_permissions_task(cc_pair_id: int) -> str:
+    return f"sync_external_doc_permissions_task__{cc_pair_id}"
+
+
+def name_sync_external_group_permissions_task(cc_pair_id: int) -> str:
+    return f"sync_external_group_permissions_task__{cc_pair_id}"
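These builders produce one deterministic name per cc_pair, which is exactly what the should_perform_* checks above look up via get_latest_task. For illustration (not part of the commit):

# One canonical task name per cc_pair lets get_latest_task() find the
# previous run of the same sync and skip a still-live one.
assert name_sync_external_doc_permissions_task(cc_pair_id=1) == (
    "sync_external_doc_permissions_task__1"
)
assert name_sync_external_group_permissions_task(cc_pair_id=1) == (
    "sync_external_group_permissions_task__1"
)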
@@ -0,0 +1,18 @@
+from typing import Any
+
+from atlassian import Confluence  # type:ignore
+
+
+def build_confluence_client(
+    connector_specific_config: dict[str, Any], raw_credentials_json: dict[str, Any]
+) -> Confluence:
+    is_cloud = connector_specific_config.get("is_cloud", False)
+    return Confluence(
+        api_version="cloud" if is_cloud else "latest",
+        # Remove trailing slash from wiki_base if present
+        url=connector_specific_config["wiki_base"].rstrip("/"),
+        # passing in username causes issues for Confluence data center
+        username=raw_credentials_json["confluence_username"] if is_cloud else None,
+        password=raw_credentials_json["confluence_access_token"] if is_cloud else None,
+        token=raw_credentials_json["confluence_access_token"] if not is_cloud else None,
+    )
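The helper selects between Confluence Cloud's username/API-token auth and a data-center personal access token. A usage sketch with made-up values (not part of the commit):

# Cloud: username + API token, api_version="cloud"
cloud_client = build_confluence_client(
    connector_specific_config={
        "wiki_base": "https://example.atlassian.net/wiki/",  # hypothetical
        "is_cloud": True,
    },
    raw_credentials_json={
        "confluence_username": "user@example.com",  # hypothetical
        "confluence_access_token": "api-token",  # hypothetical
    },
)

# Data center / server: token auth only; username stays unset to avoid the
# issue noted in the comment above.
dc_client = build_confluence_client(
    connector_specific_config={"wiki_base": "https://confluence.example.com"},
    raw_credentials_json={"confluence_access_token": "personal-access-token"},
)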
@@ -1,19 +1,254 @@
 from typing import Any

+from atlassian import Confluence  # type:ignore
 from sqlalchemy.orm import Session

+from danswer.access.models import ExternalAccess
+from danswer.connectors.confluence.confluence_utils import (
+    build_confluence_document_id,
+)
+from danswer.connectors.confluence.rate_limit_handler import (
+    make_confluence_call_handle_rate_limit,
+)
 from danswer.db.models import ConnectorCredentialPair
+from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
 from danswer.utils.logger import setup_logger
-from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
+from ee.danswer.db.document import upsert_document_external_perms__no_commit
+from ee.danswer.external_permissions.confluence.confluence_sync_utils import (
+    build_confluence_client,
+)


 logger = setup_logger()

+_REQUEST_PAGINATION_LIMIT = 100
+
+
+def _get_space_permissions(
+    db_session: Session,
+    confluence_client: Confluence,
+    space_id: str,
+) -> ExternalAccess:
+    get_space_permissions = make_confluence_call_handle_rate_limit(
+        confluence_client.get_space_permissions
+    )
+
+    space_permissions = get_space_permissions(space_id).get("permissions", [])
+    user_emails = set()
+    # Confluence enforces that group names are unique
+    group_names = set()
+    is_externally_public = False
+    for permission in space_permissions:
+        subs = permission.get("subjects")
+        if subs:
+            # If there are subjects, then there are explicit users or groups with access
+            if email := subs.get("user", {}).get("results", [{}])[0].get("email"):
+                user_emails.add(email)
+            if group_name := subs.get("group", {}).get("results", [{}])[0].get("name"):
+                group_names.add(group_name)
+        else:
+            # If there are no subjects, then the permission is for everyone
+            if permission.get("operation", {}).get(
+                "operation"
+            ) == "read" and permission.get("anonymousAccess", False):
+                # If the permission specifies read access for anonymous users, then
+                # the space is publicly accessible
+                is_externally_public = True
+    batch_add_non_web_user_if_not_exists__no_commit(
+        db_session=db_session, emails=list(user_emails)
+    )
+    return ExternalAccess(
+        external_user_emails=user_emails,
+        external_user_group_ids=group_names,
+        is_public=is_externally_public,
+    )
+
+
+def _get_restrictions_for_page(
+    db_session: Session,
+    page: dict[str, Any],
+    space_permissions: ExternalAccess,
+) -> ExternalAccess:
+    """
+    WARNING: This function includes no pagination. So if a page is private within
+    the space and has over 200 users or over 200 groups with explicitly read access,
+    this function will leave out some users or groups.
+    200 is a large amount so it is unlikely, but just be aware.
+    """
+    restrictions_json = page.get("restrictions", {})
+    read_access_dict = restrictions_json.get("read", {}).get("restrictions", {})
+
+    read_access_user_jsons = read_access_dict.get("user", {}).get("results", [])
+    read_access_group_jsons = read_access_dict.get("group", {}).get("results", [])
+
+    is_space_public = read_access_user_jsons == [] and read_access_group_jsons == []
+
+    if not is_space_public:
+        read_access_user_emails = [
+            user["email"] for user in read_access_user_jsons if user.get("email")
+        ]
+        read_access_groups = [group["name"] for group in read_access_group_jsons]
+        batch_add_non_web_user_if_not_exists__no_commit(
+            db_session=db_session, emails=list(read_access_user_emails)
+        )
+        external_access = ExternalAccess(
+            external_user_emails=set(read_access_user_emails),
+            external_user_group_ids=set(read_access_groups),
+            is_public=False,
+        )
+    else:
+        external_access = space_permissions
+
+    return external_access
+
+
+def _fetch_attachment_document_ids_for_page_paginated(
+    confluence_client: Confluence, page: dict[str, Any]
+) -> list[str]:
+    """
+    Starts by just extracting the first page of attachments from
+    the page. If all attachments are in the first page, then
+    no calls to the api are made from this function.
+    """
+    get_attachments_from_content = make_confluence_call_handle_rate_limit(
+        confluence_client.get_attachments_from_content
+    )
+
+    attachment_doc_ids = []
+    attachments_dict = page["children"]["attachment"]
+    start = 0
+
+    while True:
+        attachments_list = attachments_dict["results"]
+        attachment_doc_ids.extend(
+            [
+                build_confluence_document_id(
+                    base_url=confluence_client.url,
+                    content_url=attachment["_links"]["download"],
+                )
+                for attachment in attachments_list
+            ]
+        )
+
+        if "next" not in attachments_dict["_links"]:
+            break
+
+        start += len(attachments_list)
+        attachments_dict = get_attachments_from_content(
+            page_id=page["id"],
+            start=start,
+            limit=_REQUEST_PAGINATION_LIMIT,
+        )
+
+    return attachment_doc_ids
+
+
+def _fetch_all_pages_paginated(
+    confluence_client: Confluence,
+    space_id: str,
+) -> list[dict[str, Any]]:
+    get_all_pages_from_space = make_confluence_call_handle_rate_limit(
+        confluence_client.get_all_pages_from_space
+    )
+
+    # For each page, this fetches the page's attachments and restrictions.
+    expansion_strings = [
+        "children.attachment",
+        "restrictions.read.restrictions.user",
+        "restrictions.read.restrictions.group",
+    ]
+    expansion_string = ",".join(expansion_strings)
+
+    all_pages = []
+    start = 0
+    while True:
+        pages_dict = get_all_pages_from_space(
+            space=space_id,
+            start=start,
+            limit=_REQUEST_PAGINATION_LIMIT,
+            expand=expansion_string,
+        )
+        all_pages.extend(pages_dict)
+
+        response_size = len(pages_dict)
+        if response_size < _REQUEST_PAGINATION_LIMIT:
+            break
+        start += response_size
+
+    return all_pages
+
+
+def _fetch_all_page_restrictions_for_space(
+    db_session: Session,
+    confluence_client: Confluence,
+    space_id: str,
+    space_permissions: ExternalAccess,
+) -> dict[str, ExternalAccess]:
+    all_pages = _fetch_all_pages_paginated(
+        confluence_client=confluence_client,
+        space_id=space_id,
+    )
+
+    document_restrictions: dict[str, ExternalAccess] = {}
+    for page in all_pages:
+        """
+        This assigns the same permissions to all attachments of a page and
+        the page itself.
+        This is because the attachments are stored in the same Confluence space as the page.
+        WARNING: We create a dbDocument entry for all attachments, even though attachments
+        may not be their own standalone documents. This is likely fine as we just upsert a
+        document with just permissions.
+        """
+        attachment_document_ids = [
+            build_confluence_document_id(
+                base_url=confluence_client.url,
+                content_url=page["_links"]["webui"],
+            )
+        ]
+        attachment_document_ids.extend(
+            _fetch_attachment_document_ids_for_page_paginated(
+                confluence_client=confluence_client, page=page
+            )
+        )
+        page_permissions = _get_restrictions_for_page(
+            db_session=db_session,
+            page=page,
+            space_permissions=space_permissions,
+        )
+        for attachment_document_id in attachment_document_ids:
+            document_restrictions[attachment_document_id] = page_permissions
+
+    return document_restrictions


 def confluence_doc_sync(
     db_session: Session,
     cc_pair: ConnectorCredentialPair,
-    docs_with_additional_info: list[DocsWithAdditionalInfo],
-    sync_details: dict[str, Any],
 ) -> None:
-    logger.debug("Not yet implemented ACL sync for confluence, no-op")
+    """
+    Adds the external permissions to the documents in postgres
+    if the document doesn't already exist in postgres, we create
+    it in postgres so that when it gets created later, the permissions are
+    already populated
+    """
+    confluence_client = build_confluence_client(
+        cc_pair.connector.connector_specific_config, cc_pair.credential.credential_json
+    )
+    space_permissions = _get_space_permissions(
+        db_session=db_session,
+        confluence_client=confluence_client,
+        space_id=cc_pair.connector.connector_specific_config["space"],
+    )
+    fresh_doc_permissions = _fetch_all_page_restrictions_for_space(
+        db_session=db_session,
+        confluence_client=confluence_client,
+        space_id=cc_pair.connector.connector_specific_config["space"],
+        space_permissions=space_permissions,
+    )
+    for doc_id, ext_access in fresh_doc_permissions.items():
+        upsert_document_external_perms__no_commit(
+            db_session=db_session,
+            doc_id=doc_id,
+            external_access=ext_access,
+            source_type=cc_pair.connector.source,
+        )
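To make the folding in _get_space_permissions concrete, here is a hypothetical sketch of the only fields it reads from the space-permission payload (the payload shape is an assumption based on the keys accessed above, not part of the commit):

space_permissions = [
    # A permission with subjects contributes an explicit user or group
    {"subjects": {"user": {"results": [{"email": "a@example.com"}]}}},
    {"subjects": {"group": {"results": [{"name": "engineering"}]}}},
    # A subject-less anonymous read grant marks the space public
    {"operation": {"operation": "read"}, "anonymousAccess": True},
]

# Folding these as the function does would yield:
# ExternalAccess(
#     external_user_emails={"a@example.com"},
#     external_user_group_ids={"engineering"},
#     is_public=True,
# )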
@@ -1,19 +1,107 @@
-from typing import Any
+from collections.abc import Iterator

+from atlassian import Confluence  # type:ignore
+from requests import HTTPError
 from sqlalchemy.orm import Session

+from danswer.connectors.confluence.rate_limit_handler import (
+    make_confluence_call_handle_rate_limit,
+)
 from danswer.db.models import ConnectorCredentialPair
+from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
 from danswer.utils.logger import setup_logger
-from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
+from ee.danswer.db.external_perm import ExternalUserGroup
+from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
+from ee.danswer.external_permissions.confluence.confluence_sync_utils import (
+    build_confluence_client,
+)


 logger = setup_logger()

+_PAGE_SIZE = 100
+
+
+def _get_confluence_group_names_paginated(
+    confluence_client: Confluence,
+) -> Iterator[str]:
+    get_all_groups = make_confluence_call_handle_rate_limit(
+        confluence_client.get_all_groups
+    )
+
+    start = 0
+    while True:
+        try:
+            groups = get_all_groups(start=start, limit=_PAGE_SIZE)
+        except HTTPError as e:
+            if e.response.status_code in (403, 404):
+                return
+            raise e
+
+        for group in groups:
+            if group_name := group.get("name"):
+                yield group_name
+
+        if len(groups) < _PAGE_SIZE:
+            break
+        start += _PAGE_SIZE
+
+
+def _get_group_members_email_paginated(
+    confluence_client: Confluence,
+    group_name: str,
+) -> list[str]:
+    get_group_members = make_confluence_call_handle_rate_limit(
+        confluence_client.get_group_members
+    )
+    group_member_emails: list[str] = []
+    start = 0
+    while True:
+        try:
+            members = get_group_members(
+                group_name=group_name, start=start, limit=_PAGE_SIZE
+            )
+        except HTTPError as e:
+            if e.response.status_code == 403 or e.response.status_code == 404:
+                return group_member_emails
+            raise e
+
+        group_member_emails.extend(
+            [member.get("email") for member in members if member.get("email")]
+        )
+        if len(members) < _PAGE_SIZE:
+            break
+        start += _PAGE_SIZE
+    return group_member_emails


 def confluence_group_sync(
     db_session: Session,
     cc_pair: ConnectorCredentialPair,
-    docs_with_additional_info: list[DocsWithAdditionalInfo],
-    sync_details: dict[str, Any],
 ) -> None:
-    logger.debug("Not yet implemented group sync for confluence, no-op")
+    confluence_client = build_confluence_client(
+        cc_pair.connector.connector_specific_config, cc_pair.credential.credential_json
+    )
+
+    danswer_groups: list[ExternalUserGroup] = []
+    # Confluence enforces that group names are unique
+    for group_name in _get_confluence_group_names_paginated(confluence_client):
+        group_member_emails = _get_group_members_email_paginated(
+            confluence_client, group_name
+        )
+        group_members = batch_add_non_web_user_if_not_exists__no_commit(
+            db_session=db_session, emails=group_member_emails
+        )
+        if group_members:
+            danswer_groups.append(
+                ExternalUserGroup(
+                    id=group_name, user_ids=[user.id for user in group_members]
+                )
+            )
+
+    replace_user__ext_group_for_cc_pair__no_commit(
+        db_session=db_session,
+        cc_pair_id=cc_pair.id,
+        group_defs=danswer_groups,
+        source=cc_pair.connector.source,
+    )
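Both Confluence modules route every API call through make_confluence_call_handle_rate_limit (the "dealt with Confluence rate limits" item in the commit message). Its implementation is not part of this diff; a wrapper of this kind plausibly retries on HTTP 429 with backoff, roughly:

# A sketch under that assumption, not the committed implementation.
import time
from collections.abc import Callable
from typing import Any

from requests import HTTPError


def _retry_on_rate_limit(call: Callable[..., Any], tries: int = 5) -> Callable[..., Any]:
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        for attempt in range(tries):
            try:
                return call(*args, **kwargs)
            except HTTPError as e:
                if e.response.status_code != 429 or attempt == tries - 1:
                    raise
                time.sleep(2**attempt)  # back off before retrying

    return wrapped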
@@ -1,4 +1,6 @@
+from collections.abc import Iterator
 from datetime import datetime
 from datetime import timezone
 from typing import Any
 from typing import cast
@@ -8,15 +10,17 @@ from sqlalchemy.orm import Session

 from danswer.access.models import ExternalAccess
 from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
+from danswer.connectors.factory import instantiate_connector
 from danswer.connectors.google_drive.connector_auth import (
     get_google_drive_creds,
 )
 from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
+from danswer.connectors.interfaces import PollConnector
+from danswer.connectors.models import InputType
 from danswer.db.models import ConnectorCredentialPair
 from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
 from danswer.utils.logger import setup_logger
 from ee.danswer.db.document import upsert_document_external_perms__no_commit
-from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo

 # Google Drive APIs are quite flakey and may 500 for an
 # extended period of time. Trying to combat here by adding a very
@@ -27,6 +31,42 @@ add_retries = retry_builder(tries=5, delay=5, max_delay=30)
 logger = setup_logger()


+def _get_docs_with_additional_info(
+    db_session: Session,
+    cc_pair: ConnectorCredentialPair,
+) -> dict[str, Any]:
+    # Get all document ids that need their permissions updated
+    runnable_connector = instantiate_connector(
+        db_session=db_session,
+        source=cc_pair.connector.source,
+        input_type=InputType.POLL,
+        connector_specific_config=cc_pair.connector.connector_specific_config,
+        credential=cc_pair.credential,
+    )
+
+    assert isinstance(runnable_connector, PollConnector)
+
+    current_time = datetime.now(timezone.utc)
+    start_time = (
+        cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp()
+        if cc_pair.last_time_perm_sync
+        else 0.0
+    )
+    cc_pair.last_time_perm_sync = current_time
+
+    doc_batch_generator = runnable_connector.poll_source(
+        start=start_time, end=current_time.timestamp()
+    )
+
+    docs_with_additional_info = {
+        doc.id: doc.additional_info
+        for doc_batch in doc_batch_generator
+        for doc in doc_batch
+    }
+
+    return docs_with_additional_info
+
+
 def _fetch_permissions_paginated(
     drive_service: Any, drive_file_id: str
 ) -> Iterator[dict[str, Any]]:
@@ -122,8 +162,6 @@ def _fetch_google_permissions_for_document_id(
 def gdrive_doc_sync(
     db_session: Session,
     cc_pair: ConnectorCredentialPair,
-    docs_with_additional_info: list[DocsWithAdditionalInfo],
-    sync_details: dict[str, Any],
 ) -> None:
     """
     Adds the external permissions to the documents in postgres
@@ -131,10 +169,24 @@ def gdrive_doc_sync(
     it in postgres so that when it gets created later, the permissions are
     already populated
     """
-    for doc in docs_with_additional_info:
+    sync_details = cc_pair.auto_sync_options
+    if sync_details is None:
+        logger.error("Sync details not found for Google Drive")
+        raise ValueError("Sync details not found for Google Drive")
+
+    # Here we run the connector to grab all the ids
+    # this may grab ids before they are indexed but that is fine because
+    # we create a document in postgres to hold the permissions info
+    # until the indexing job has a chance to run
+    docs_with_additional_info = _get_docs_with_additional_info(
+        db_session=db_session,
+        cc_pair=cc_pair,
+    )
+
+    for doc_id, doc_additional_info in docs_with_additional_info.items():
         ext_access = _fetch_google_permissions_for_document_id(
             db_session=db_session,
-            drive_file_id=doc.additional_info,
+            drive_file_id=doc_additional_info,
             raw_credentials_json=cc_pair.credential.credential_json,
             company_google_domains=[
                 cast(dict[str, str], sync_details)["company_domain"]
@@ -142,7 +194,7 @@ def gdrive_doc_sync(
         )
         upsert_document_external_perms__no_commit(
             db_session=db_session,
-            doc_id=doc.id,
+            doc_id=doc_id,
             external_access=ext_access,
             source_type=cc_pair.connector.source,
         )
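The refactor above replaces the old list of pydantic models with a plain dict keyed by document id, so the loop unpacks both values explicitly instead of mixing up doc.id and doc.additional_info. Illustration (not part of the commit):

docs_with_additional_info = {"doc-1": "drive-file-id-1"}  # hypothetical values

for doc_id, doc_additional_info in docs_with_additional_info.items():
    # permissions are fetched for doc_additional_info and upserted under doc_id
    ...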
@@ -17,7 +17,6 @@ from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
 from danswer.utils.logger import setup_logger
 from ee.danswer.db.external_perm import ExternalUserGroup
 from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
-from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo

 logger = setup_logger()

@@ -105,9 +104,12 @@ def _fetch_group_members_paginated(
 def gdrive_group_sync(
     db_session: Session,
     cc_pair: ConnectorCredentialPair,
-    docs_with_additional_info: list[DocsWithAdditionalInfo],
-    sync_details: dict[str, Any],
 ) -> None:
+    sync_details = cc_pair.auto_sync_options
+    if sync_details is None:
+        logger.error("Sync details not found for Google Drive")
+        raise ValueError("Sync details not found for Google Drive")
+
     google_drive_creds, _ = get_google_drive_creds(
         cc_pair.credential.credential_json,
         scopes=FETCH_GROUPS_SCOPES,
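Both Google Drive sync functions now read sync_details from cc_pair.auto_sync_options instead of taking it as a parameter. Based on the company_domain lookup in gdrive_doc_sync above, a minimal value might look like this (hypothetical, not part of the commit):

auto_sync_options = {"company_domain": "example.com"}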
@@ -4,32 +4,68 @@ from datetime import timezone
 from sqlalchemy.orm import Session

 from danswer.access.access import get_access_for_documents
+from danswer.configs.constants import DocumentSource
 from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
-from danswer.db.search_settings import get_current_search_settings
-from danswer.document_index.factory import get_default_document_index
+from danswer.db.document import get_document_ids_for_connector_credential_pair
+from danswer.document_index.factory import get_current_primary_default_document_index
 from danswer.document_index.interfaces import UpdateRequest
 from danswer.utils.logger import setup_logger
 from ee.danswer.external_permissions.permission_sync_function_map import (
     DOC_PERMISSIONS_FUNC_MAP,
 )
-from ee.danswer.external_permissions.permission_sync_function_map import (
-    FULL_FETCH_PERIOD_IN_SECONDS,
-)
 from ee.danswer.external_permissions.permission_sync_function_map import (
     GROUP_PERMISSIONS_FUNC_MAP,
 )
-from ee.danswer.external_permissions.permission_sync_utils import (
-    get_docs_with_additional_info,
-)

 logger = setup_logger()


-def run_permission_sync_entrypoint(
+# None means that the connector runs every time
+_RESTRICTED_FETCH_PERIOD: dict[DocumentSource, int | None] = {
+    # Polling is supported
+    DocumentSource.GOOGLE_DRIVE: None,
+    # Polling is not supported so we fetch all doc permissions every 5 minutes
+    DocumentSource.CONFLUENCE: 5 * 60,
+}
+
+
+def run_external_group_permission_sync(
+    db_session: Session,
+    cc_pair_id: int,
+) -> None:
+    cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
+    if cc_pair is None:
+        raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")
+
+    source_type = cc_pair.connector.source
+    group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
+
+    if group_sync_func is None:
+        # Not all sync connectors support group permissions so this is fine
+        return
+
+    try:
+        # This function updates:
+        # - the user_email <-> external_user_group_id mapping
+        # in postgres without committing
+        logger.debug(f"Syncing groups for {source_type}")
+        group_sync_func(
+            db_session,
+            cc_pair,
+        )
+
+        # update postgres
+        db_session.commit()
+    except Exception as e:
+        logger.error(f"Error updating document index: {e}")
+        db_session.rollback()
+
+
+def run_external_doc_permission_sync(
     db_session: Session,
     cc_pair_id: int,
 ) -> None:
-    # TODO: seperate out group and doc sync
     cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
     if cc_pair is None:
         raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")
@@ -37,20 +73,16 @@ def run_permission_sync_entrypoint(
     source_type = cc_pair.connector.source

     doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
-    group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)

     if doc_sync_func is None:
         raise ValueError(
             f"No permission sync function found for source type: {source_type}"
         )

     sync_details = cc_pair.auto_sync_options
     if sync_details is None:
         raise ValueError(f"No auto sync options found for source type: {source_type}")

-    # If the source type is not polling, we only fetch the permissions every
-    # _FULL_FETCH_PERIOD_IN_SECONDS seconds
-    full_fetch_period = FULL_FETCH_PERIOD_IN_SECONDS[source_type]
+    # If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
+    # If RESTRICTED_FETCH_PERIOD is not None, we only run sync if the
+    # last sync was more than RESTRICTED_FETCH_PERIOD seconds ago.
+    full_fetch_period = _RESTRICTED_FETCH_PERIOD[cc_pair.connector.source]
     if full_fetch_period is not None:
         last_sync = cc_pair.last_time_perm_sync
         if (
@@ -62,65 +94,45 @@ def run_permission_sync_entrypoint(
     ):
         return

-    # Here we run the connector to grab all the ids
-    # this may grab ids before they are indexed but that is fine because
-    # we create a document in postgres to hold the permissions info
-    # until the indexing job has a chance to run
-    docs_with_additional_info = get_docs_with_additional_info(
-        db_session=db_session,
-        cc_pair=cc_pair,
-    )
-
-    # This function updates:
-    # - the user_email <-> external_user_group_id mapping
-    # in postgres without committing
-    logger.debug(f"Syncing groups for {source_type}")
-    if group_sync_func is not None:
-        group_sync_func(
-            db_session,
-            cc_pair,
-            docs_with_additional_info,
-            sync_details,
-        )
-
-    # This function updates:
-    # - the user_email <-> document mapping
-    # - the external_user_group_id <-> document mapping
-    # in postgres without committing
-    logger.debug(f"Syncing docs for {source_type}")
-    doc_sync_func(
-        db_session,
-        cc_pair,
-        docs_with_additional_info,
-        sync_details,
-    )
+    try:
+        # This function updates:
+        # - the user_email <-> document mapping
+        # - the external_user_group_id <-> document mapping
+        # in postgres without committing
+        logger.debug(f"Syncing docs for {source_type}")
+        doc_sync_func(
+            db_session,
+            cc_pair,
+        )

-    # This function fetches the updated access for the documents
-    # and returns a dictionary of document_ids and access
-    # This is the access we want to update vespa with
-    docs_access = get_access_for_documents(
-        document_ids=[doc.id for doc in docs_with_additional_info],
-        db_session=db_session,
-    )
+        # Get the document ids for the cc pair
+        document_ids_for_cc_pair = get_document_ids_for_connector_credential_pair(
+            db_session=db_session,
+            connector_id=cc_pair.connector_id,
+            credential_id=cc_pair.credential_id,
+        )
+
+        # This function fetches the updated access for the documents
+        # and returns a dictionary of document_ids and access
+        # This is the access we want to update vespa with
+        docs_access = get_access_for_documents(
+            document_ids=document_ids_for_cc_pair,
+            db_session=db_session,
+        )

-    # Then we build the update requests to update vespa
-    update_reqs = [
-        UpdateRequest(document_ids=[doc_id], access=doc_access)
-        for doc_id, doc_access in docs_access.items()
-    ]
+        # Then we build the update requests to update vespa
+        update_reqs = [
+            UpdateRequest(document_ids=[doc_id], access=doc_access)
+            for doc_id, doc_access in docs_access.items()
+        ]

-    # Don't bother sync-ing secondary, it will be sync-ed after switch anyway
-    search_settings = get_current_search_settings(db_session)
-    document_index = get_default_document_index(
-        primary_index_name=search_settings.index_name,
-        secondary_index_name=None,
-    )
+        # Don't bother sync-ing secondary, it will be sync-ed after switch anyway
+        document_index = get_current_primary_default_document_index(db_session)

-    try:
         # update vespa
         document_index.update(update_reqs)
         # update postgres
         db_session.commit()
     except Exception as e:
-        logger.error(f"Error updating document index: {e}")
+        logger.error(f"Error Syncing Permissions: {e}")
         db_session.rollback()
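The restricted-fetch gate is truncated in the hunk above at `if (`. Presumably it compares the time since last_time_perm_sync against the per-source period; a sketch under that assumption (not the committed code):

from datetime import datetime, timezone


def _should_skip_fetch(last_sync: datetime | None, full_fetch_period: int) -> bool:
    # Skip the full fetch when a sync ran within the allowed period
    if last_sync is None:
        return False
    elapsed = datetime.now(timezone.utc) - last_sync.replace(tzinfo=timezone.utc)
    return elapsed.total_seconds() < full_fetch_period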
@@ -1,5 +1,4 @@
 from collections.abc import Callable
-from typing import Any

 from sqlalchemy.orm import Session

@@ -9,15 +8,14 @@ from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_sync
 from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync
 from ee.danswer.external_permissions.google_drive.doc_sync import gdrive_doc_sync
 from ee.danswer.external_permissions.google_drive.group_sync import gdrive_group_sync
-from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo

-GroupSyncFuncType = Callable[
-    [Session, ConnectorCredentialPair, list[DocsWithAdditionalInfo], dict[str, Any]],
-    None,
-]
-
-DocSyncFuncType = Callable[
-    [Session, ConnectorCredentialPair, list[DocsWithAdditionalInfo], dict[str, Any]],
+# Defining the input/output types for the sync functions
+SyncFuncType = Callable[
+    [
+        Session,
+        ConnectorCredentialPair,
+    ],
     None,
 ]
@@ -26,7 +24,7 @@ DocSyncFuncType = Callable[
 # - the external_user_group_id <-> document mapping
 # in postgres without committing
 # THIS ONE IS NECESSARY FOR AUTO SYNC TO WORK
-DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
+DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
     DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync,
     DocumentSource.CONFLUENCE: confluence_doc_sync,
 }
@@ -35,20 +33,11 @@ DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
 # - the user_email <-> external_user_group_id mapping
 # in postgres without committing
 # THIS ONE IS OPTIONAL ON AN APP BY APP BASIS
-GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, GroupSyncFuncType] = {
+GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
     DocumentSource.GOOGLE_DRIVE: gdrive_group_sync,
     DocumentSource.CONFLUENCE: confluence_group_sync,
 }


-# None means that the connector supports polling from last_time_perm_sync to now
-FULL_FETCH_PERIOD_IN_SECONDS: dict[DocumentSource, int | None] = {
-    # Polling is supported
-    DocumentSource.GOOGLE_DRIVE: None,
-    # Polling is not supported so we fetch all doc permissions every 10 minutes
-    DocumentSource.CONFLUENCE: 10 * 60,
-}
-
-
 def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
     return source_type in DOC_PERMISSIONS_FUNC_MAP
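With both maps sharing SyncFuncType, adding a connector is a matter of registering its sync functions. A hypothetical sketch (slack_doc_sync / slack_group_sync are made-up names; real implementations must match (Session, ConnectorCredentialPair) -> None):

DOC_PERMISSIONS_FUNC_MAP[DocumentSource.SLACK] = slack_doc_sync  # required for auto sync
GROUP_PERMISSIONS_FUNC_MAP[DocumentSource.SLACK] = slack_group_sync  # optional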
@@ -1,56 +0,0 @@
-from datetime import datetime
-from datetime import timezone
-from typing import Any
-
-from pydantic import BaseModel
-from sqlalchemy.orm import Session
-
-from danswer.connectors.factory import instantiate_connector
-from danswer.connectors.interfaces import PollConnector
-from danswer.connectors.models import InputType
-from danswer.db.models import ConnectorCredentialPair
-from danswer.utils.logger import setup_logger
-
-logger = setup_logger()
-
-
-class DocsWithAdditionalInfo(BaseModel):
-    id: str
-    additional_info: Any
-
-
-def get_docs_with_additional_info(
-    db_session: Session,
-    cc_pair: ConnectorCredentialPair,
-) -> list[DocsWithAdditionalInfo]:
-    # Get all document ids that need their permissions updated
-    runnable_connector = instantiate_connector(
-        db_session=db_session,
-        source=cc_pair.connector.source,
-        input_type=InputType.POLL,
-        connector_specific_config=cc_pair.connector.connector_specific_config,
-        credential=cc_pair.credential,
-    )
-
-    assert isinstance(runnable_connector, PollConnector)
-
-    current_time = datetime.now(timezone.utc)
-    start_time = (
-        cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp()
-        if cc_pair.last_time_perm_sync
-        else 0
-    )
-    cc_pair.last_time_perm_sync = current_time
-
-    doc_batch_generator = runnable_connector.poll_source(
-        start=start_time, end=current_time.timestamp()
-    )
-
-    docs_with_additional_info = [
-        DocsWithAdditionalInfo(id=doc.id, additional_info=doc.additional_info)
-        for doc_batch in doc_batch_generator
-        for doc in doc_batch
-    ]
-    logger.debug(f"Docs with additional info: {len(docs_with_additional_info)}")
-
-    return docs_with_additional_info