Added confluence permission syncing (#2537)

* Added confluence permission syncing

* separated out group and doc syncing

* minor bugfix and mypy

* added frontend and fixed bug

* Minor refactor

* dealt with confluence rate limits!

* mypy fixes!!!

* addressed Yuhong's feedback

* primary key fix
hagen-danswer 2024-09-26 15:10:41 -07:00 committed by GitHub
parent 6d48fd5d99
commit b97cc01bb2
22 changed files with 725 additions and 222 deletions

View File

@ -0,0 +1,46 @@
"""fix_user__external_user_group_id_fk
Revision ID: 46b7a812670f
Revises: f32615f71aeb
Create Date: 2024-09-23 12:58:03.894038
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "46b7a812670f"
down_revision = "f32615f71aeb"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop the existing primary key
op.drop_constraint(
"user__external_user_group_id_pkey",
"user__external_user_group_id",
type_="primary",
)
# Add the new composite primary key
op.create_primary_key(
"user__external_user_group_id_pkey",
"user__external_user_group_id",
["user_id", "external_user_group_id", "cc_pair_id"],
)
def downgrade() -> None:
# Drop the composite primary key
op.drop_constraint(
"user__external_user_group_id_pkey",
"user__external_user_group_id",
type_="primary",
)
# Delete all entries from the table
op.execute("DELETE FROM user__external_user_group_id")
# Recreate the original primary key on user_id
op.create_primary_key(
"user__external_user_group_id_pkey", "user__external_user_group_id", ["user_id"]
)
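For intuition, a sketch with hypothetical values (not part of the migration): under the old primary key on user_id alone, each user could map to only one external group row, while the new composite key allows one row per (user, group, cc_pair).

# Hypothetical rows in user__external_user_group_id. The second and third
# rows would have violated the old single-column primary key; under the
# composite key (user_id, external_user_group_id, cc_pair_id) all three coexist.
rows = [
    {"user_id": "u1", "external_user_group_id": "engineering", "cc_pair_id": 1},
    {"user_id": "u1", "external_user_group_id": "legal", "cc_pair_id": 1},
    {"user_id": "u1", "external_user_group_id": "engineering", "cc_pair_id": 2},
]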

View File

@ -0,0 +1,32 @@
import bs4
def build_confluence_document_id(base_url: str, content_url: str) -> str:
"""For confluence, the document id is the page url for a page based document
or the attachment download url for an attachment based document
Args:
base_url (str): The base url of the Confluence instance
content_url (str): The url of the page or attachment download url
Returns:
str: The document id
"""
return f"{base_url}{content_url}"
def get_used_attachments(text: str) -> list[str]:
"""Parse a Confluence html page to generate a list of current
attachment in used
Args:
text (str): The page content
Returns:
list[str]: List of filenames currently in use by the page text
"""
files_in_used = []
soup = bs4.BeautifulSoup(text, "html.parser")
for attachment in soup.findAll("ri:attachment"):
files_in_used.append(attachment.attrs["ri:filename"])
return files_in_used
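A quick usage sketch of these two helpers (URLs and page content are hypothetical):

from danswer.connectors.confluence.confluence_utils import (
    build_confluence_document_id,
    get_used_attachments,
)

# The document id is simply the base URL concatenated with the content URL.
doc_id = build_confluence_document_id(
    "https://example.atlassian.net/wiki", "/spaces/ENG/pages/123"
)
assert doc_id == "https://example.atlassian.net/wiki/spaces/ENG/pages/123"

# Attachment references are pulled out of Confluence storage-format HTML.
page_html = '<p>See <ri:attachment ri:filename="diagram.png" /></p>'
assert get_used_attachments(page_html) == ["diagram.png"]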

View File

@ -22,6 +22,10 @@ from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.confluence.confluence_utils import (
build_confluence_document_id,
)
from danswer.connectors.confluence.confluence_utils import get_used_attachments
from danswer.connectors.confluence.rate_limit_handler import (
make_confluence_call_handle_rate_limit,
)
@ -105,24 +109,6 @@ def parse_html_page(text: str, confluence_client: Confluence) -> str:
return format_document_soup(soup)
def get_used_attachments(text: str, confluence_client: Confluence) -> list[str]:
"""Parse a Confluence html page to generate a list of current
attachment in used
Args:
text (str): The page content
confluence_client (Confluence): Confluence client
Returns:
list[str]: List of filenames currently in use
"""
files_in_used = []
soup = bs4.BeautifulSoup(text, "html.parser")
for attachment in soup.findAll("ri:attachment"):
files_in_used.append(attachment.attrs["ri:filename"])
return files_in_used
def _comment_dfs(
comments_str: str,
comment_pages: Collection[dict[str, Any]],
@ -624,13 +610,16 @@ class ConfluenceConnector(LoadConnector, PollConnector):
page_html = (
page["body"].get("storage", page["body"].get("view", {})).get("value")
)
page_url = self.wiki_base + page["_links"]["webui"]
# The url and the id are the same
page_url = build_confluence_document_id(
self.wiki_base, page["_links"]["webui"]
)
if not page_html:
logger.debug("Page is empty, skipping: %s", page_url)
continue
page_text = parse_html_page(page_html, self.confluence_client)
files_in_used = get_used_attachments(page_html, self.confluence_client)
files_in_used = get_used_attachments(page_html)
attachment_text, unused_page_attachments = self._fetch_attachments(
self.confluence_client, page_id, files_in_used
)
@ -683,8 +672,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if time_filter and not time_filter(last_updated):
continue
attachment_url = self._attachment_to_download_link(
self.confluence_client, attachment
# The url and the id are the same
attachment_url = build_confluence_document_id(
self.wiki_base, attachment["_links"]["download"]
)
attachment_content = self._attachment_to_content(
self.confluence_client, attachment

View File

@ -104,6 +104,18 @@ def construct_document_select_for_connector_credential_pair(
return stmt
def get_document_ids_for_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
) -> list[str]:
doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
if limit is not None:
doc_ids_stmt = doc_ids_stmt.limit(limit)
return list(db_session.execute(doc_ids_stmt).scalars().all())
def get_documents_for_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
) -> Sequence[DbDocument]:
@ -120,8 +132,8 @@ def get_documents_for_connector_credential_pair(
def get_documents_by_ids(
document_ids: list[str],
db_session: Session,
document_ids: list[str],
) -> list[DbDocument]:
stmt = select(DbDocument).where(DbDocument.id.in_(document_ids))
documents = db_session.execute(stmt).scalars().all()

View File

@ -1725,7 +1725,9 @@ class User__ExternalUserGroupId(Base):
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
# These group ids have been prefixed by the source type
external_user_group_id: Mapped[str] = mapped_column(String, primary_key=True)
cc_pair_id: Mapped[int] = mapped_column(ForeignKey("connector_credential_pair.id"))
cc_pair_id: Mapped[int] = mapped_column(
ForeignKey("connector_credential_pair.id"), primary_key=True
)
class UsageReport(Base):

View File

@ -1,3 +1,6 @@
from sqlalchemy.orm import Session
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.vespa.index import VespaIndex
@ -13,3 +16,14 @@ def get_default_document_index(
return VespaIndex(
index_name=primary_index_name, secondary_index_name=secondary_index_name
)
def get_current_primary_default_document_index(db_session: Session) -> DocumentIndex:
"""
TODO: Use redis to cache this or something
"""
search_settings = get_current_search_settings(db_session)
return get_default_document_index(
primary_index_name=search_settings.index_name,
secondary_index_name=None,
)
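A minimal usage sketch, mirroring how this commit's callers use the helper (the engine import path is assumed):

from sqlalchemy.orm import Session
from danswer.db.engine import get_sqlalchemy_engine  # assumed import path

with Session(get_sqlalchemy_engine()) as db_session:
    # Only the current primary index; the secondary index is deliberately skipped.
    document_index = get_current_primary_default_document_index(db_session)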

View File

@ -220,8 +220,8 @@ def index_doc_batch_prepare(
document_ids = [document.id for document in documents]
db_docs: list[DBDocument] = get_documents_by_ids(
document_ids=document_ids,
db_session=db_session,
document_ids=document_ids,
)
# Skip indexing docs that don't have a newer updated at

View File

@ -11,12 +11,25 @@ from danswer.server.settings.store import load_settings
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from ee.danswer.background.celery_utils import should_perform_chat_ttl_check
from ee.danswer.background.celery_utils import should_perform_external_permissions_check
from ee.danswer.background.celery_utils import (
should_perform_external_doc_permissions_check,
)
from ee.danswer.background.celery_utils import (
should_perform_external_group_permissions_check,
)
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import name_sync_external_permissions_task
from ee.danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from ee.danswer.background.task_name_builders import (
name_sync_external_group_permissions_task,
)
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.external_permissions.permission_sync import (
run_permission_sync_entrypoint,
run_external_doc_permission_sync,
)
from ee.danswer.external_permissions.permission_sync import (
run_external_group_permission_sync,
)
from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
@ -26,11 +39,18 @@ logger = setup_logger()
global_version.set_ee()
@build_celery_task_wrapper(name_sync_external_permissions_task)
@build_celery_task_wrapper(name_sync_external_doc_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_permissions_task(cc_pair_id: int) -> None:
def sync_external_doc_permissions_task(cc_pair_id: int) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
run_permission_sync_entrypoint(db_session=db_session, cc_pair_id=cc_pair_id)
run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
@build_celery_task_wrapper(name_sync_external_group_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_group_permissions_task(cc_pair_id: int) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
@build_celery_task_wrapper(name_chat_ttl_task)
@ -44,18 +64,35 @@ def perform_ttl_management_task(retention_limit_days: int) -> None:
# Periodic Tasks
#####
@celery_app.task(
name="check_sync_external_permissions_task",
name="check_sync_external_doc_permissions_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_permissions_task() -> None:
def check_sync_external_doc_permissions_task() -> None:
"""Runs periodically to sync external permissions"""
with Session(get_sqlalchemy_engine()) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
if should_perform_external_permissions_check(
if should_perform_external_doc_permissions_check(
cc_pair=cc_pair, db_session=db_session
):
sync_external_permissions_task.apply_async(
sync_external_doc_permissions_task.apply_async(
kwargs=dict(cc_pair_id=cc_pair.id),
)
@celery_app.task(
name="check_sync_external_group_permissions_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_group_permissions_task() -> None:
"""Runs periodically to sync external group permissions"""
with Session(get_sqlalchemy_engine()) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
if should_perform_external_group_permissions_check(
cc_pair=cc_pair, db_session=db_session
):
sync_external_group_permissions_task.apply_async(
kwargs=dict(cc_pair_id=cc_pair.id),
)
@ -94,9 +131,13 @@ def autogenerate_usage_report_task() -> None:
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
"sync-external-permissions": {
"task": "check_sync_external_permissions_task",
"schedule": timedelta(seconds=60), # TODO: optimize this
"sync-external-doc-permissions": {
"task": "check_sync_external_doc_permissions_task",
"schedule": timedelta(seconds=5), # TODO: optimize this
},
"sync-external-group-permissions": {
"task": "check_sync_external_group_permissions_task",
"schedule": timedelta(seconds=5), # TODO: optimize this
},
"autogenerate_usage_report": {
"task": "autogenerate_usage_report_task",

View File

@ -12,7 +12,12 @@ from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.utils.logger import setup_logger
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import name_sync_external_permissions_task
from ee.danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from ee.danswer.background.task_name_builders import (
name_sync_external_group_permissions_task,
)
from ee.danswer.db.user_group import delete_user_group
from ee.danswer.db.user_group import fetch_user_group
from ee.danswer.db.user_group import mark_user_group_as_synced
@ -38,13 +43,32 @@ def should_perform_chat_ttl_check(
return True
def should_perform_external_permissions_check(
def should_perform_external_doc_permissions_check(
cc_pair: ConnectorCredentialPair, db_session: Session
) -> bool:
if cc_pair.access_type != AccessType.SYNC:
return False
task_name = name_sync_external_permissions_task(cc_pair_id=cc_pair.id)
task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair.id)
latest_task = get_latest_task(task_name, db_session)
if not latest_task:
return True
if check_task_is_live_and_not_timed_out(latest_task, db_session):
logger.debug(f"{task_name} is already being performed. Skipping.")
return False
return True
def should_perform_external_group_permissions_check(
cc_pair: ConnectorCredentialPair, db_session: Session
) -> bool:
if cc_pair.access_type != AccessType.SYNC:
return False
task_name = name_sync_external_group_permissions_task(cc_pair_id=cc_pair.id)
latest_task = get_latest_task(task_name, db_session)
if not latest_task:

View File

@ -2,5 +2,9 @@ def name_chat_ttl_task(retention_limit_days: int) -> str:
return f"chat_ttl_{retention_limit_days}_days"
def name_sync_external_permissions_task(cc_pair_id: int) -> str:
return f"sync_external_permissions_task__{cc_pair_id}"
def name_sync_external_doc_permissions_task(cc_pair_id: int) -> str:
return f"sync_external_doc_permissions_task__{cc_pair_id}"
def name_sync_external_group_permissions_task(cc_pair_id: int) -> str:
return f"sync_external_group_permissions_task__{cc_pair_id}"

View File

@ -0,0 +1,18 @@
from typing import Any
from atlassian import Confluence # type:ignore
def build_confluence_client(
connector_specific_config: dict[str, Any], raw_credentials_json: dict[str, Any]
) -> Confluence:
is_cloud = connector_specific_config.get("is_cloud", False)
return Confluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present
url=connector_specific_config["wiki_base"].rstrip("/"),
# passing in username causes issues for Confluence data center
username=raw_credentials_json["confluence_username"] if is_cloud else None,
password=raw_credentials_json["confluence_access_token"] if is_cloud else None,
token=raw_credentials_json["confluence_access_token"] if not is_cloud else None,
)
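A usage sketch with hypothetical credentials, showing how the cloud and data-center paths differ:

# Cloud: username + API token are passed as basic auth.
cloud_client = build_confluence_client(
    {"wiki_base": "https://example.atlassian.net/wiki/", "is_cloud": True},
    {"confluence_username": "me@example.com", "confluence_access_token": "api-token"},
)

# Data center: only a personal access token is passed, no username.
dc_client = build_confluence_client(
    {"wiki_base": "https://confluence.internal.example.com"},
    {"confluence_username": "ignored", "confluence_access_token": "personal-access-token"},
)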

View File

@ -1,19 +1,254 @@
from typing import Any
from atlassian import Confluence # type:ignore
from sqlalchemy.orm import Session
from danswer.access.models import ExternalAccess
from danswer.connectors.confluence.confluence_utils import (
build_confluence_document_id,
)
from danswer.connectors.confluence.rate_limit_handler import (
make_confluence_call_handle_rate_limit,
)
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
from ee.danswer.db.document import upsert_document_external_perms__no_commit
from ee.danswer.external_permissions.confluence.confluence_sync_utils import (
build_confluence_client,
)
logger = setup_logger()
_REQUEST_PAGINATION_LIMIT = 100
def _get_space_permissions(
db_session: Session,
confluence_client: Confluence,
space_id: str,
) -> ExternalAccess:
get_space_permissions = make_confluence_call_handle_rate_limit(
confluence_client.get_space_permissions
)
space_permissions = get_space_permissions(space_id).get("permissions", [])
user_emails = set()
# Confluence enforces that group names are unique
group_names = set()
is_externally_public = False
for permission in space_permissions:
subs = permission.get("subjects")
if subs:
# If there are subjects, then there are explicit users or groups with access
# A permission entry may list multiple users and groups
for user_result in subs.get("user", {}).get("results", []):
if email := user_result.get("email"):
user_emails.add(email)
for group_result in subs.get("group", {}).get("results", []):
if group_name := group_result.get("name"):
group_names.add(group_name)
else:
# If there are no subjects, then the permission is for everyone
if permission.get("operation", {}).get(
"operation"
) == "read" and permission.get("anonymousAccess", False):
# If the permission specifies read access for anonymous users, then
# the space is publicly accessible
is_externally_public = True
batch_add_non_web_user_if_not_exists__no_commit(
db_session=db_session, emails=list(user_emails)
)
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_names,
is_public=is_externally_public,
)
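The parser above assumes space permission entries shaped roughly like this (a trimmed, hypothetical sketch; real Confluence payloads carry more fields):

space_permissions_sample = {
    "permissions": [
        {   # explicit grant: subjects name users and/or groups
            "subjects": {
                "user": {"results": [{"email": "alice@example.com"}]},
                "group": {"results": [{"name": "engineering"}]},
            },
        },
        {   # no subjects: anonymous read access marks the space as public
            "operation": {"operation": "read"},
            "anonymousAccess": True,
        },
    ]
}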
def _get_restrictions_for_page(
db_session: Session,
page: dict[str, Any],
space_permissions: ExternalAccess,
) -> ExternalAccess:
"""
WARNING: This function does no pagination. If a page is restricted within
the space and has over 200 users or over 200 groups with explicit read
access, some of those users or groups will be missed.
200 is a high threshold, so this is unlikely, but be aware of it.
"""
restrictions_json = page.get("restrictions", {})
read_access_dict = restrictions_json.get("read", {}).get("restrictions", {})
read_access_user_jsons = read_access_dict.get("user", {}).get("results", [])
read_access_group_jsons = read_access_dict.get("group", {}).get("results", [])
page_has_read_restrictions = bool(read_access_user_jsons or read_access_group_jsons)
if page_has_read_restrictions:
read_access_user_emails = [
user["email"] for user in read_access_user_jsons if user.get("email")
]
read_access_groups = [group["name"] for group in read_access_group_jsons]
batch_add_non_web_user_if_not_exists__no_commit(
db_session=db_session, emails=list(read_access_user_emails)
)
external_access = ExternalAccess(
external_user_emails=set(read_access_user_emails),
external_user_group_ids=set(read_access_groups),
is_public=False,
)
else:
external_access = space_permissions
return external_access
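Likewise, the page restrictions read here look roughly like this (hypothetical sketch):

page_sample = {
    "restrictions": {
        "read": {
            "restrictions": {
                "user": {"results": [{"email": "bob@example.com"}]},
                "group": {"results": [{"name": "legal"}]},
            }
        }
    }
}
# Empty user and group results mean the page has no explicit restrictions
# and simply inherits the space-level permissions.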
def _fetch_attachment_document_ids_for_page_paginated(
confluence_client: Confluence, page: dict[str, Any]
) -> list[str]:
"""
Starts with the batch of attachments already embedded in the page
expansion. If all attachments fit in that first batch, this function
makes no additional API calls.
"""
get_attachments_from_content = make_confluence_call_handle_rate_limit(
confluence_client.get_attachments_from_content
)
attachment_doc_ids = []
attachments_dict = page["children"]["attachment"]
start = 0
while True:
attachments_list = attachments_dict["results"]
attachment_doc_ids.extend(
[
build_confluence_document_id(
base_url=confluence_client.url,
content_url=attachment["_links"]["download"],
)
for attachment in attachments_list
]
)
if "next" not in attachments_dict["_links"]:
break
start += len(attachments_list)
attachments_dict = get_attachments_from_content(
page_id=page["id"],
start=start,
limit=_REQUEST_PAGINATION_LIMIT,
)
return attachment_doc_ids
def _fetch_all_pages_paginated(
confluence_client: Confluence,
space_id: str,
) -> list[dict[str, Any]]:
get_all_pages_from_space = make_confluence_call_handle_rate_limit(
confluence_client.get_all_pages_from_space
)
# For each page, this fetches the page's attachments and restrictions.
expansion_strings = [
"children.attachment",
"restrictions.read.restrictions.user",
"restrictions.read.restrictions.group",
]
expansion_string = ",".join(expansion_strings)
all_pages = []
start = 0
while True:
pages_batch = get_all_pages_from_space(
space=space_id,
start=start,
limit=_REQUEST_PAGINATION_LIMIT,
expand=expansion_string,
)
all_pages.extend(pages_batch)
response_size = len(pages_batch)
if response_size < _REQUEST_PAGINATION_LIMIT:
break
start += response_size
return all_pages
def _fetch_all_page_restrictions_for_space(
db_session: Session,
confluence_client: Confluence,
space_id: str,
space_permissions: ExternalAccess,
) -> dict[str, ExternalAccess]:
all_pages = _fetch_all_pages_paginated(
confluence_client=confluence_client,
space_id=space_id,
)
document_restrictions: dict[str, ExternalAccess] = {}
for page in all_pages:
"""
This assigns the same permissions to all attachments of a page and
the page itself.
This is because the attachments are stored in the same Confluence space as the page.
WARNING: We create a dbDocument entry for all attachments, even though attachments
may not be their own standalone documents. This is likely fine as we just upsert a
document with just permissions.
"""
attachment_document_ids = [
build_confluence_document_id(
base_url=confluence_client.url,
content_url=page["_links"]["webui"],
)
]
attachment_document_ids.extend(
_fetch_attachment_document_ids_for_page_paginated(
confluence_client=confluence_client, page=page
)
)
page_permissions = _get_restrictions_for_page(
db_session=db_session,
page=page,
space_permissions=space_permissions,
)
for attachment_document_id in attachment_document_ids:
document_restrictions[attachment_document_id] = page_permissions
return document_restrictions
def confluence_doc_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
docs_with_additional_info: list[DocsWithAdditionalInfo],
sync_details: dict[str, Any],
) -> None:
logger.debug("Not yet implemented ACL sync for confluence, no-op")
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exists in postgres, we create
it in postgres so that when it gets created later, the permissions are
already populated
"""
confluence_client = build_confluence_client(
cc_pair.connector.connector_specific_config, cc_pair.credential.credential_json
)
space_permissions = _get_space_permissions(
db_session=db_session,
confluence_client=confluence_client,
space_id=cc_pair.connector.connector_specific_config["space"],
)
fresh_doc_permissions = _fetch_all_page_restrictions_for_space(
db_session=db_session,
confluence_client=confluence_client,
space_id=cc_pair.connector.connector_specific_config["space"],
space_permissions=space_permissions,
)
for doc_id, ext_access in fresh_doc_permissions.items():
upsert_document_external_perms__no_commit(
db_session=db_session,
doc_id=doc_id,
external_access=ext_access,
source_type=cc_pair.connector.source,
)

View File

@ -1,19 +1,107 @@
from typing import Any
from collections.abc import Iterator
from atlassian import Confluence # type:ignore
from requests import HTTPError
from sqlalchemy.orm import Session
from danswer.connectors.confluence.rate_limit_handler import (
make_confluence_call_handle_rate_limit,
)
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.confluence.confluence_sync_utils import (
build_confluence_client,
)
logger = setup_logger()
_PAGE_SIZE = 100
def _get_confluence_group_names_paginated(
confluence_client: Confluence,
) -> Iterator[str]:
get_all_groups = make_confluence_call_handle_rate_limit(
confluence_client.get_all_groups
)
start = 0
while True:
try:
groups = get_all_groups(start=start, limit=_PAGE_SIZE)
except HTTPError as e:
if e.response.status_code in (403, 404):
return
raise e
for group in groups:
if group_name := group.get("name"):
yield group_name
if len(groups) < _PAGE_SIZE:
break
start += _PAGE_SIZE
def _get_group_members_email_paginated(
confluence_client: Confluence,
group_name: str,
) -> list[str]:
get_group_members = make_confluence_call_handle_rate_limit(
confluence_client.get_group_members
)
group_member_emails: list[str] = []
start = 0
while True:
try:
members = get_group_members(
group_name=group_name, start=start, limit=_PAGE_SIZE
)
except HTTPError as e:
if e.response.status_code in (403, 404):
return group_member_emails
raise e
group_member_emails.extend(
[member.get("email") for member in members if member.get("email")]
)
if len(members) < _PAGE_SIZE:
break
start += _PAGE_SIZE
return group_member_emails
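Together, these two helpers stream every Confluence group and its member emails; a quick sketch of how confluence_group_sync below drives them (client construction assumed):

for group_name in _get_confluence_group_names_paginated(confluence_client):
    member_emails = _get_group_members_email_paginated(confluence_client, group_name)
    print(f"{group_name}: {len(member_emails)} members with visible emails")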
def confluence_group_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
docs_with_additional_info: list[DocsWithAdditionalInfo],
sync_details: dict[str, Any],
) -> None:
logger.debug("Not yet implemented group sync for confluence, no-op")
confluence_client = build_confluence_client(
cc_pair.connector.connector_specific_config, cc_pair.credential.credential_json
)
danswer_groups: list[ExternalUserGroup] = []
# Confluence enforces that group names are unique
for group_name in _get_confluence_group_names_paginated(confluence_client):
group_member_emails = _get_group_members_email_paginated(
confluence_client, group_name
)
group_members = batch_add_non_web_user_if_not_exists__no_commit(
db_session=db_session, emails=group_member_emails
)
if group_members:
danswer_groups.append(
ExternalUserGroup(
id=group_name, user_ids=[user.id for user in group_members]
)
)
replace_user__ext_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=cc_pair.id,
group_defs=danswer_groups,
source=cc_pair.connector.source,
)

View File

@ -1,4 +1,6 @@
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
@ -8,15 +10,17 @@ from sqlalchemy.orm import Session
from danswer.access.models import ExternalAccess
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.google_drive.connector_auth import (
get_google_drive_creds,
)
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import InputType
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.document import upsert_document_external_perms__no_commit
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
# Google Drive APIs are quite flakey and may 500 for an
# extended period of time. Trying to combat here by adding a very
@ -27,6 +31,42 @@ add_retries = retry_builder(tries=5, delay=5, max_delay=30)
logger = setup_logger()
def _get_docs_with_additional_info(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> dict[str, Any]:
# Get all document ids that need their permissions updated
runnable_connector = instantiate_connector(
db_session=db_session,
source=cc_pair.connector.source,
input_type=InputType.POLL,
connector_specific_config=cc_pair.connector.connector_specific_config,
credential=cc_pair.credential,
)
assert isinstance(runnable_connector, PollConnector)
current_time = datetime.now(timezone.utc)
start_time = (
cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp()
if cc_pair.last_time_perm_sync
else 0.0
)
cc_pair.last_time_perm_sync = current_time
doc_batch_generator = runnable_connector.poll_source(
start=start_time, end=current_time.timestamp()
)
docs_with_additional_info = {
doc.id: doc.additional_info
for doc_batch in doc_batch_generator
for doc in doc_batch
}
return docs_with_additional_info
def _fetch_permissions_paginated(
drive_service: Any, drive_file_id: str
) -> Iterator[dict[str, Any]]:
@ -122,8 +162,6 @@ def _fetch_google_permissions_for_document_id(
def gdrive_doc_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
docs_with_additional_info: list[DocsWithAdditionalInfo],
sync_details: dict[str, Any],
) -> None:
"""
Adds the external permissions to the documents in postgres
@ -131,10 +169,24 @@ def gdrive_doc_sync(
it in postgres so that when it gets created later, the permissions are
already populated
"""
for doc in docs_with_additional_info:
sync_details = cc_pair.auto_sync_options
if sync_details is None:
logger.error("Sync details not found for Google Drive")
raise ValueError("Sync details not found for Google Drive")
# Here we run the connector to grab all the ids.
# This may grab ids before they are indexed, but that is fine because
# we create a document in postgres to hold the permissions info
# until the indexing job has a chance to run.
docs_with_additional_info = _get_docs_with_additional_info(
db_session=db_session,
cc_pair=cc_pair,
)
for doc_id, doc_additional_info in docs_with_additional_info.items():
ext_access = _fetch_google_permissions_for_document_id(
db_session=db_session,
drive_file_id=doc.additional_info,
drive_file_id=doc_additional_info,
raw_credentials_json=cc_pair.credential.credential_json,
company_google_domains=[
cast(dict[str, str], sync_details)["company_domain"]
@ -142,7 +194,7 @@ def gdrive_doc_sync(
)
upsert_document_external_perms__no_commit(
db_session=db_session,
doc_id=doc.id,
doc_id=doc_id,
external_access=ext_access,
source_type=cc_pair.connector.source,
)

View File

@ -17,7 +17,6 @@ from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
logger = setup_logger()
@ -105,9 +104,12 @@ def _fetch_group_members_paginated(
def gdrive_group_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
docs_with_additional_info: list[DocsWithAdditionalInfo],
sync_details: dict[str, Any],
) -> None:
sync_details = cc_pair.auto_sync_options
if sync_details is None:
logger.error("Sync details not found for Google Drive")
raise ValueError("Sync details not found for Google Drive")
google_drive_creds, _ = get_google_drive_creds(
cc_pair.credential.credential_json,
scopes=FETCH_GROUPS_SCOPES,

View File

@ -4,32 +4,68 @@ from datetime import timezone
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
from danswer.configs.constants import DocumentSource
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.db.document import get_document_ids_for_connector_credential_pair
from danswer.document_index.factory import get_current_primary_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.utils.logger import setup_logger
from ee.danswer.external_permissions.permission_sync_function_map import (
DOC_PERMISSIONS_FUNC_MAP,
)
from ee.danswer.external_permissions.permission_sync_function_map import (
FULL_FETCH_PERIOD_IN_SECONDS,
)
from ee.danswer.external_permissions.permission_sync_function_map import (
GROUP_PERMISSIONS_FUNC_MAP,
)
from ee.danswer.external_permissions.permission_sync_utils import (
get_docs_with_additional_info,
)
logger = setup_logger()
def run_permission_sync_entrypoint(
# None means that the sync runs on every periodic check
_RESTRICTED_FETCH_PERIOD: dict[DocumentSource, int | None] = {
# Polling is supported
DocumentSource.GOOGLE_DRIVE: None,
# Polling is not supported so we fetch all doc permissions every 5 minutes
DocumentSource.CONFLUENCE: 5 * 60,
}
def run_external_group_permission_sync(
db_session: Session,
cc_pair_id: int,
) -> None:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if cc_pair is None:
raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")
source_type = cc_pair.connector.source
group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
if group_sync_func is None:
# Not all sync connectors support group permissions so this is fine
return
try:
# This function updates:
# - the user_email <-> external_user_group_id mapping
# in postgres without committing
logger.debug(f"Syncing groups for {source_type}")
# group_sync_func is guaranteed non-None here (checked above)
group_sync_func(
db_session,
cc_pair,
)
# update postgres
db_session.commit()
except Exception as e:
logger.error(f"Error updating document index: {e}")
db_session.rollback()
def run_external_doc_permission_sync(
db_session: Session,
cc_pair_id: int,
) -> None:
# TODO: separate out group and doc sync
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if cc_pair is None:
raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")
@ -37,20 +73,16 @@ def run_permission_sync_entrypoint(
source_type = cc_pair.connector.source
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
if doc_sync_func is None:
raise ValueError(
f"No permission sync function found for source type: {source_type}"
)
sync_details = cc_pair.auto_sync_options
if sync_details is None:
raise ValueError(f"No auto sync options found for source type: {source_type}")
# If the source type is not polling, we only fetch the permissions every
# _FULL_FETCH_PERIOD_IN_SECONDS seconds
full_fetch_period = FULL_FETCH_PERIOD_IN_SECONDS[source_type]
# If _RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
# If it is not None, we only run the sync if the last sync was more than
# _RESTRICTED_FETCH_PERIOD seconds ago.
full_fetch_period = _RESTRICTED_FETCH_PERIOD[cc_pair.connector.source]
if full_fetch_period is not None:
last_sync = cc_pair.last_time_perm_sync
if (
@ -62,65 +94,45 @@ def run_permission_sync_entrypoint(
):
return
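The condition above is elided by the hunk break; for reference, a sketch of the full gate as the surrounding comments describe it (attribute names taken from the code above; datetime imports assumed from the top of the file):

full_fetch_period = _RESTRICTED_FETCH_PERIOD[cc_pair.connector.source]
if full_fetch_period is not None:
    last_sync = cc_pair.last_time_perm_sync
    if last_sync is not None:
        elapsed = datetime.now(timezone.utc) - last_sync.replace(tzinfo=timezone.utc)
        if elapsed.total_seconds() < full_fetch_period:
            return  # last sync was recent enough; skip this run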
# Here we run the connector to grab all the ids
# this may grab ids before they are indexed but that is fine because
# we create a document in postgres to hold the permissions info
# until the indexing job has a chance to run
docs_with_additional_info = get_docs_with_additional_info(
db_session=db_session,
cc_pair=cc_pair,
)
# This function updates:
# - the user_email <-> external_user_group_id mapping
# in postgres without committing
logger.debug(f"Syncing groups for {source_type}")
if group_sync_func is not None:
group_sync_func(
try:
# This function updates:
# - the user_email <-> document mapping
# - the external_user_group_id <-> document mapping
# in postgres without committing
logger.debug(f"Syncing docs for {source_type}")
doc_sync_func(
db_session,
cc_pair,
docs_with_additional_info,
sync_details,
)
# This function updates:
# - the user_email <-> document mapping
# - the external_user_group_id <-> document mapping
# in postgres without committing
logger.debug(f"Syncing docs for {source_type}")
doc_sync_func(
db_session,
cc_pair,
docs_with_additional_info,
sync_details,
)
# Get the document ids for the cc pair
document_ids_for_cc_pair = get_document_ids_for_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# This function fetches the updated access for the documents
# and returns a dictionary of document_ids and access
# This is the access we want to update vespa with
docs_access = get_access_for_documents(
document_ids=[doc.id for doc in docs_with_additional_info],
db_session=db_session,
)
# This function fetches the updated access for the documents
# and returns a dictionary of document_ids and access
# This is the access we want to update vespa with
docs_access = get_access_for_documents(
document_ids=document_ids_for_cc_pair,
db_session=db_session,
)
# Then we build the update requests to update vespa
update_reqs = [
UpdateRequest(document_ids=[doc_id], access=doc_access)
for doc_id, doc_access in docs_access.items()
]
# Then we build the update requests to update vespa
update_reqs = [
UpdateRequest(document_ids=[doc_id], access=doc_access)
for doc_id, doc_access in docs_access.items()
]
# Don't bother sync-ing secondary, it will be sync-ed after switch anyway
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(
primary_index_name=search_settings.index_name,
secondary_index_name=None,
)
# Don't bother sync-ing secondary, it will be sync-ed after switch anyway
document_index = get_current_primary_default_document_index(db_session)
try:
# update vespa
document_index.update(update_reqs)
# update postgres
db_session.commit()
except Exception as e:
logger.error(f"Error updating document index: {e}")
logger.error(f"Error Syncing Permissions: {e}")
db_session.rollback()

View File

@ -1,5 +1,4 @@
from collections.abc import Callable
from typing import Any
from sqlalchemy.orm import Session
@ -9,15 +8,14 @@ from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_s
from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync
from ee.danswer.external_permissions.google_drive.doc_sync import gdrive_doc_sync
from ee.danswer.external_permissions.google_drive.group_sync import gdrive_group_sync
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
GroupSyncFuncType = Callable[
[Session, ConnectorCredentialPair, list[DocsWithAdditionalInfo], dict[str, Any]],
None,
]
DocSyncFuncType = Callable[
[Session, ConnectorCredentialPair, list[DocsWithAdditionalInfo], dict[str, Any]],
# Defining the input/output types for the sync functions
SyncFuncType = Callable[
[
Session,
ConnectorCredentialPair,
],
None,
]
@ -26,7 +24,7 @@ DocSyncFuncType = Callable[
# - the external_user_group_id <-> document mapping
# in postgres without committing
# THIS ONE IS NECESSARY FOR AUTO SYNC TO WORK
DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync,
DocumentSource.CONFLUENCE: confluence_doc_sync,
}
@ -35,20 +33,11 @@ DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
# - the user_email <-> external_user_group_id mapping
# in postgres without committing
# THIS ONE IS OPTIONAL ON AN APP BY APP BASIS
GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, GroupSyncFuncType] = {
GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
DocumentSource.GOOGLE_DRIVE: gdrive_group_sync,
DocumentSource.CONFLUENCE: confluence_group_sync,
}
# None means that the connector supports polling from last_time_perm_sync to now
FULL_FETCH_PERIOD_IN_SECONDS: dict[DocumentSource, int | None] = {
# Polling is supported
DocumentSource.GOOGLE_DRIVE: None,
# Polling is not supported so we fetch all doc permissions every 10 minutes
DocumentSource.CONFLUENCE: 10 * 60,
}
def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
return source_type in DOC_PERMISSIONS_FUNC_MAP
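A dispatch sketch showing how these maps are consumed (db_session and cc_pair assumed to be in scope):

doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source)
if doc_sync_func is None:
    raise ValueError("No permission sync function found for this source type")
doc_sync_func(db_session, cc_pair)

group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source)
if group_sync_func is not None:  # group sync is optional per source
    group_sync_func(db_session, cc_pair)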

View File

@ -1,56 +0,0 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import InputType
from danswer.db.models import ConnectorCredentialPair
from danswer.utils.logger import setup_logger
logger = setup_logger()
class DocsWithAdditionalInfo(BaseModel):
id: str
additional_info: Any
def get_docs_with_additional_info(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> list[DocsWithAdditionalInfo]:
# Get all document ids that need their permissions updated
runnable_connector = instantiate_connector(
db_session=db_session,
source=cc_pair.connector.source,
input_type=InputType.POLL,
connector_specific_config=cc_pair.connector.connector_specific_config,
credential=cc_pair.credential,
)
assert isinstance(runnable_connector, PollConnector)
current_time = datetime.now(timezone.utc)
start_time = (
cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp()
if cc_pair.last_time_perm_sync
else 0
)
cc_pair.last_time_perm_sync = current_time
doc_batch_generator = runnable_connector.poll_source(
start=start_time, end=current_time.timestamp()
)
docs_with_additional_info = [
DocsWithAdditionalInfo(id=doc.id, additional_info=doc.additional_info)
for doc_batch in doc_batch_generator
for doc in doc_batch
]
logger.debug(f"Docs with additional info: {len(docs_with_additional_info)}")
return docs_with_additional_info

View File

@ -5,12 +5,9 @@ import {
ConfigurableSources,
validAutoSyncSources,
} from "@/lib/types";
import { Text, Title } from "@tremor/react";
import { useUser } from "@/components/user/UserProvider";
import { useField } from "formik";
import { AutoSyncOptions } from "./AutoSyncOptions";
import { useContext } from "react";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
function isValidAutoSyncSource(
@ -28,7 +25,6 @@ export function AccessTypeForm({
useField<AccessType>("access_type");
const isPaidEnterpriseEnabled = usePaidEnterpriseFeaturesEnabled();
const settings = useContext(SettingsContext);
const isAutoSyncSupported = isValidAutoSyncSource(connector);
const { isLoadingUser, isAdmin } = useUser();
@ -85,7 +81,7 @@ export function AccessTypeForm({
/>
{access_type.value === "sync" && isAutoSyncSupported && (
<div className="mt-6">
<div>
<AutoSyncOptions
connectorType={connector as ValidAutoSyncSources}
/>

View File

@ -1,5 +1,4 @@
import { TextFormField } from "@/components/admin/connectors/Field";
import { useFormikContext } from "formik";
import { ValidAutoSyncSources } from "@/lib/types";
import { Divider } from "@tremor/react";
import { autoSyncConfigBySource } from "@/lib/connectors/AutoSyncOptionFields";
@ -9,22 +8,24 @@ export function AutoSyncOptions({
}: {
connectorType: ValidAutoSyncSources;
}) {
const autoSyncConfig = autoSyncConfigBySource[connectorType];
if (Object.keys(autoSyncConfig).length === 0) {
return null;
}
return (
<div>
<Divider />
<>
{Object.entries(autoSyncConfigBySource[connectorType]).map(
([key, config]) => (
<div key={key} className="mb-4">
<TextFormField
name={`auto_sync_options.${key}`}
label={config.label}
subtext={config.subtext}
/>
</div>
)
)}
</>
{Object.entries(autoSyncConfig).map(([key, config]) => (
<div key={key} className="mb-4">
<TextFormField
name={`auto_sync_options.${key}`}
label={config.label}
subtext={config.subtext}
/>
</div>
))}
</div>
);
}

View File

@ -11,6 +11,7 @@ export const autoSyncConfigBySource: Record<
}
>
> = {
confluence: {},
google_drive: {
customer_id: {
label: "Google Workspace Customer ID",

View File

@ -264,5 +264,5 @@ export type ConfigurableSources = Exclude<
>;
// The sources that have auto-sync support on the backend
export const validAutoSyncSources = ["google_drive"] as const;
export const validAutoSyncSources = ["confluence", "google_drive"] as const;
export type ValidAutoSyncSources = (typeof validAutoSyncSources)[number];