Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-05-30 01:30:21 +02:00)
Added confluence permission syncing (#2537)
* Added confluence permission syncing
* Separated out group and doc syncing
* Minor bugfix and mypy
* Added frontend and fixed bug
* Minor refactor
* Dealt with confluence rate limits!
* mypy fixes!!!
* Addressed yuhong feedback
* Primary key fix
This commit is contained in:
parent 6d48fd5d99
commit b97cc01bb2
backend/alembic/versions/ (new migration, revision 46b7a812670f)

@@ -0,0 +1,46 @@
"""fix_user__external_user_group_id_fk

Revision ID: 46b7a812670f
Revises: f32615f71aeb
Create Date: 2024-09-23 12:58:03.894038

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "46b7a812670f"
down_revision = "f32615f71aeb"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Drop the existing primary key
    op.drop_constraint(
        "user__external_user_group_id_pkey",
        "user__external_user_group_id",
        type_="primary",
    )
    # Add the new composite primary key
    op.create_primary_key(
        "user__external_user_group_id_pkey",
        "user__external_user_group_id",
        ["user_id", "external_user_group_id", "cc_pair_id"],
    )


def downgrade() -> None:
    # Drop the composite primary key
    op.drop_constraint(
        "user__external_user_group_id_pkey",
        "user__external_user_group_id",
        type_="primary",
    )
    # Delete all entries from the table
    op.execute("DELETE FROM user__external_user_group_id")

    # Recreate the original primary key on user_id
    op.create_primary_key(
        "user__external_user_group_id_pkey", "user__external_user_group_id", ["user_id"]
    )
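Why the composite key matters, as a minimal sketch (the values are hypothetical; the User__ExternalUserGroupId model this table backs appears later in this diff). This is the "primary key fix" called out in the commit message:

# Rows are keyed by (user_id, external_user_group_id, cc_pair_id) after the upgrade.
# Both rows below can now coexist; under the old key the second insert collided,
# so one user could not carry the same group membership for two cc pairs.
rows = [
    ("user-1", "confluence_engineering", 1),  # membership seen by cc pair 1
    ("user-1", "confluence_engineering", 2),  # same membership, cc pair 2
]
# The downgrade deletes all rows first because collapsing back to
# PRIMARY KEY (user_id) would fail for any user with more than one row.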
backend/danswer/connectors/confluence/confluence_utils.py (new file, +32 lines)
@@ -0,0 +1,32 @@
import bs4


def build_confluence_document_id(base_url: str, content_url: str) -> str:
    """For Confluence, the document id is the page url for a page-based document
    or the attachment download url for an attachment-based document.

    Args:
        base_url (str): The base url of the Confluence instance
        content_url (str): The url of the page or attachment download url

    Returns:
        str: The document id
    """
    return f"{base_url}{content_url}"


def get_used_attachments(text: str) -> list[str]:
    """Parse a Confluence html page to generate a list of the
    attachments currently in use.

    Args:
        text (str): The page content

    Returns:
        list[str]: List of filenames currently in use by the page text
    """
    files_in_used = []
    soup = bs4.BeautifulSoup(text, "html.parser")
    for attachment in soup.findAll("ri:attachment"):
        files_in_used.append(attachment.attrs["ri:filename"])
    return files_in_used
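A quick usage sketch of the two helpers above; the HTML fragment and URLs are made-up illustrations of Confluence storage format, not taken from a real instance:

from danswer.connectors.confluence.confluence_utils import (
    build_confluence_document_id,
    get_used_attachments,
)

page_html = '<p>Diagram:</p><ac:image><ri:attachment ri:filename="arch.png" /></ac:image>'
assert get_used_attachments(page_html) == ["arch.png"]

doc_id = build_confluence_document_id(
    base_url="https://example.atlassian.net/wiki",     # hypothetical base url
    content_url="/download/attachments/123/arch.png",  # hypothetical download link
)
assert doc_id == "https://example.atlassian.net/wiki/download/attachments/123/arch.png"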
backend/danswer/connectors/confluence/connector.py

@@ -22,6 +22,10 @@ from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING
 from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
+from danswer.connectors.confluence.confluence_utils import (
+    build_confluence_document_id,
+)
+from danswer.connectors.confluence.confluence_utils import get_used_attachments
 from danswer.connectors.confluence.rate_limit_handler import (
     make_confluence_call_handle_rate_limit,
 )
@@ -105,24 +109,6 @@ def parse_html_page(text: str, confluence_client: Confluence) -> str:
     return format_document_soup(soup)


-def get_used_attachments(text: str, confluence_client: Confluence) -> list[str]:
-    """Parse a Confluence html page to generate a list of the
-    attachments currently in use
-
-    Args:
-        text (str): The page content
-        confluence_client (Confluence): Confluence client
-
-    Returns:
-        list[str]: List of filenames currently in use
-    """
-    files_in_used = []
-    soup = bs4.BeautifulSoup(text, "html.parser")
-    for attachment in soup.findAll("ri:attachment"):
-        files_in_used.append(attachment.attrs["ri:filename"])
-    return files_in_used
-
-
 def _comment_dfs(
     comments_str: str,
     comment_pages: Collection[dict[str, Any]],
@@ -624,13 +610,16 @@ class ConfluenceConnector(LoadConnector, PollConnector):
             page_html = (
                 page["body"].get("storage", page["body"].get("view", {})).get("value")
             )
-            page_url = self.wiki_base + page["_links"]["webui"]
+            # The url and the id are the same
+            page_url = build_confluence_document_id(
+                self.wiki_base, page["_links"]["webui"]
+            )
             if not page_html:
                 logger.debug("Page is empty, skipping: %s", page_url)
                 continue
             page_text = parse_html_page(page_html, self.confluence_client)

-            files_in_used = get_used_attachments(page_html, self.confluence_client)
+            files_in_used = get_used_attachments(page_html)
             attachment_text, unused_page_attachments = self._fetch_attachments(
                 self.confluence_client, page_id, files_in_used
             )
@@ -683,8 +672,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
             if time_filter and not time_filter(last_updated):
                 continue

-            attachment_url = self._attachment_to_download_link(
-                self.confluence_client, attachment
+            # The url and the id are the same
+            attachment_url = build_confluence_document_id(
+                self.wiki_base, attachment["_links"]["download"]
             )
             attachment_content = self._attachment_to_content(
                 self.confluence_client, attachment
backend/danswer/db/document.py

@@ -104,6 +104,18 @@ def construct_document_select_for_connector_credential_pair(
     return stmt


+def get_document_ids_for_connector_credential_pair(
+    db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
+) -> list[str]:
+    doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
+        and_(
+            DocumentByConnectorCredentialPair.connector_id == connector_id,
+            DocumentByConnectorCredentialPair.credential_id == credential_id,
+        )
+    )
+    return list(db_session.execute(doc_ids_stmt).scalars().all())
+
+
 def get_documents_for_connector_credential_pair(
     db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
 ) -> Sequence[DbDocument]:
@@ -120,8 +132,8 @@ def get_documents_for_connector_credential_pair(


 def get_documents_by_ids(
-    document_ids: list[str],
     db_session: Session,
+    document_ids: list[str],
 ) -> list[DbDocument]:
     stmt = select(DbDocument).where(DbDocument.id.in_(document_ids))
     documents = db_session.execute(stmt).scalars().all()
backend/danswer/db/models.py

@@ -1725,7 +1725,9 @@ class User__ExternalUserGroupId(Base):
     user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
     # These group ids have been prefixed by the source type
     external_user_group_id: Mapped[str] = mapped_column(String, primary_key=True)
-    cc_pair_id: Mapped[int] = mapped_column(ForeignKey("connector_credential_pair.id"))
+    cc_pair_id: Mapped[int] = mapped_column(
+        ForeignKey("connector_credential_pair.id"), primary_key=True
+    )


 class UsageReport(Base):
backend/danswer/document_index/factory.py

@@ -1,3 +1,6 @@
+from sqlalchemy.orm import Session
+
+from danswer.db.search_settings import get_current_search_settings
 from danswer.document_index.interfaces import DocumentIndex
 from danswer.document_index.vespa.index import VespaIndex

@@ -13,3 +16,14 @@ def get_default_document_index(
     return VespaIndex(
         index_name=primary_index_name, secondary_index_name=secondary_index_name
     )
+
+
+def get_current_primary_default_document_index(db_session: Session) -> DocumentIndex:
+    """
+    TODO: Use redis to cache this or something
+    """
+    search_settings = get_current_search_settings(db_session)
+    return get_default_document_index(
+        primary_index_name=search_settings.index_name,
+        secondary_index_name=None,
+    )
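The new helper exists to collapse a lookup pattern that callers repeated; the before/after below is taken from call sites elsewhere in this diff:

# Before: each caller resolved the current index name itself
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(
    primary_index_name=search_settings.index_name,
    secondary_index_name=None,
)

# After: one call returns the same primary-only index
document_index = get_current_primary_default_document_index(db_session)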
backend/danswer/indexing/indexing_pipeline.py

@@ -220,8 +220,8 @@ def index_doc_batch_prepare(

     document_ids = [document.id for document in documents]
     db_docs: list[DBDocument] = get_documents_by_ids(
-        document_ids=document_ids,
         db_session=db_session,
+        document_ids=document_ids,
     )

     # Skip indexing docs that don't have a newer updated at
backend/ee/danswer/background/celery/celery_app.py

@@ -11,12 +11,25 @@ from danswer.server.settings.store import load_settings
 from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import global_version
 from ee.danswer.background.celery_utils import should_perform_chat_ttl_check
-from ee.danswer.background.celery_utils import should_perform_external_permissions_check
+from ee.danswer.background.celery_utils import (
+    should_perform_external_doc_permissions_check,
+)
+from ee.danswer.background.celery_utils import (
+    should_perform_external_group_permissions_check,
+)
 from ee.danswer.background.task_name_builders import name_chat_ttl_task
-from ee.danswer.background.task_name_builders import name_sync_external_permissions_task
+from ee.danswer.background.task_name_builders import (
+    name_sync_external_doc_permissions_task,
+)
+from ee.danswer.background.task_name_builders import (
+    name_sync_external_group_permissions_task,
+)
 from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
 from ee.danswer.external_permissions.permission_sync import (
-    run_permission_sync_entrypoint,
+    run_external_doc_permission_sync,
+)
+from ee.danswer.external_permissions.permission_sync import (
+    run_external_group_permission_sync,
 )
 from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report

@@ -26,11 +39,18 @@ logger = setup_logger()
 global_version.set_ee()


-@build_celery_task_wrapper(name_sync_external_permissions_task)
+@build_celery_task_wrapper(name_sync_external_doc_permissions_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
-def sync_external_permissions_task(cc_pair_id: int) -> None:
+def sync_external_doc_permissions_task(cc_pair_id: int) -> None:
     with Session(get_sqlalchemy_engine()) as db_session:
-        run_permission_sync_entrypoint(db_session=db_session, cc_pair_id=cc_pair_id)
+        run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
+
+
+@build_celery_task_wrapper(name_sync_external_group_permissions_task)
+@celery_app.task(soft_time_limit=JOB_TIMEOUT)
+def sync_external_group_permissions_task(cc_pair_id: int) -> None:
+    with Session(get_sqlalchemy_engine()) as db_session:
+        run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)


 @build_celery_task_wrapper(name_chat_ttl_task)

@@ -44,18 +64,35 @@ def perform_ttl_management_task(retention_limit_days: int) -> None:
 # Periodic Tasks
 #####
 @celery_app.task(
-    name="check_sync_external_permissions_task",
+    name="check_sync_external_doc_permissions_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def check_sync_external_permissions_task() -> None:
+def check_sync_external_doc_permissions_task() -> None:
     """Runs periodically to sync external permissions"""
     with Session(get_sqlalchemy_engine()) as db_session:
         cc_pairs = get_all_auto_sync_cc_pairs(db_session)
         for cc_pair in cc_pairs:
-            if should_perform_external_permissions_check(
+            if should_perform_external_doc_permissions_check(
                 cc_pair=cc_pair, db_session=db_session
             ):
-                sync_external_permissions_task.apply_async(
+                sync_external_doc_permissions_task.apply_async(
                     kwargs=dict(cc_pair_id=cc_pair.id),
                 )


+@celery_app.task(
+    name="check_sync_external_group_permissions_task",
+    soft_time_limit=JOB_TIMEOUT,
+)
+def check_sync_external_group_permissions_task() -> None:
+    """Runs periodically to sync external group permissions"""
+    with Session(get_sqlalchemy_engine()) as db_session:
+        cc_pairs = get_all_auto_sync_cc_pairs(db_session)
+        for cc_pair in cc_pairs:
+            if should_perform_external_group_permissions_check(
+                cc_pair=cc_pair, db_session=db_session
+            ):
+                sync_external_group_permissions_task.apply_async(
+                    kwargs=dict(cc_pair_id=cc_pair.id),
+                )
+
+
@@ -94,9 +131,13 @@ def autogenerate_usage_report_task() -> None:
 # Celery Beat (Periodic Tasks) Settings
 #####
 celery_app.conf.beat_schedule = {
-    "sync-external-permissions": {
-        "task": "check_sync_external_permissions_task",
-        "schedule": timedelta(seconds=60),  # TODO: optimize this
+    "sync-external-doc-permissions": {
+        "task": "check_sync_external_doc_permissions_task",
+        "schedule": timedelta(seconds=5),  # TODO: optimize this
+    },
+    "sync-external-group-permissions": {
+        "task": "check_sync_external_group_permissions_task",
+        "schedule": timedelta(seconds=5),  # TODO: optimize this
     },
     "autogenerate_usage_report": {
         "task": "autogenerate_usage_report_task",
backend/ee/danswer/background/celery_utils.py

@@ -12,7 +12,12 @@ from danswer.db.tasks import check_task_is_live_and_not_timed_out
 from danswer.db.tasks import get_latest_task
 from danswer.utils.logger import setup_logger
 from ee.danswer.background.task_name_builders import name_chat_ttl_task
-from ee.danswer.background.task_name_builders import name_sync_external_permissions_task
+from ee.danswer.background.task_name_builders import (
+    name_sync_external_doc_permissions_task,
+)
+from ee.danswer.background.task_name_builders import (
+    name_sync_external_group_permissions_task,
+)
 from ee.danswer.db.user_group import delete_user_group
 from ee.danswer.db.user_group import fetch_user_group
 from ee.danswer.db.user_group import mark_user_group_as_synced

@@ -38,13 +43,32 @@ def should_perform_chat_ttl_check(
     return True


-def should_perform_external_permissions_check(
+def should_perform_external_doc_permissions_check(
     cc_pair: ConnectorCredentialPair, db_session: Session
 ) -> bool:
     if cc_pair.access_type != AccessType.SYNC:
         return False

-    task_name = name_sync_external_permissions_task(cc_pair_id=cc_pair.id)
+    task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair.id)

     latest_task = get_latest_task(task_name, db_session)
     if not latest_task:
         return True

     if check_task_is_live_and_not_timed_out(latest_task, db_session):
         logger.debug(f"{task_name} is already being performed. Skipping.")
         return False

     return True


+def should_perform_external_group_permissions_check(
+    cc_pair: ConnectorCredentialPair, db_session: Session
+) -> bool:
+    if cc_pair.access_type != AccessType.SYNC:
+        return False
+
+    task_name = name_sync_external_group_permissions_task(cc_pair_id=cc_pair.id)
+
+    latest_task = get_latest_task(task_name, db_session)
+    if not latest_task:
backend/ee/danswer/background/task_name_builders.py

@@ -2,5 +2,9 @@ def name_chat_ttl_task(retention_limit_days: int) -> str:
     return f"chat_ttl_{retention_limit_days}_days"


-def name_sync_external_permissions_task(cc_pair_id: int) -> str:
-    return f"sync_external_permissions_task__{cc_pair_id}"
+def name_sync_external_doc_permissions_task(cc_pair_id: int) -> str:
+    return f"sync_external_doc_permissions_task__{cc_pair_id}"
+
+
+def name_sync_external_group_permissions_task(cc_pair_id: int) -> str:
+    return f"sync_external_group_permissions_task__{cc_pair_id}"
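These builders give each connector/credential pair a stable task name, which is what lets the should_perform_external_*_permissions_check functions in celery_utils above find the latest run via get_latest_task and skip pairs with a still-live task. For example:

name_sync_external_doc_permissions_task(cc_pair_id=7)
# -> "sync_external_doc_permissions_task__7"
name_sync_external_group_permissions_task(cc_pair_id=7)
# -> "sync_external_group_permissions_task__7"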
backend/ee/danswer/external_permissions/confluence/confluence_sync_utils.py (new file)

@@ -0,0 +1,18 @@
from typing import Any

from atlassian import Confluence  # type:ignore


def build_confluence_client(
    connector_specific_config: dict[str, Any], raw_credentials_json: dict[str, Any]
) -> Confluence:
    is_cloud = connector_specific_config.get("is_cloud", False)
    return Confluence(
        api_version="cloud" if is_cloud else "latest",
        # Remove trailing slash from wiki_base if present
        url=connector_specific_config["wiki_base"].rstrip("/"),
        # passing in username causes issues for Confluence data center
        username=raw_credentials_json["confluence_username"] if is_cloud else None,
        password=raw_credentials_json["confluence_access_token"] if is_cloud else None,
        token=raw_credentials_json["confluence_access_token"] if not is_cloud else None,
    )
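A usage sketch of the client builder; the credential values are placeholders, but the dict shapes match what the function reads:

# Confluence Cloud: username + API token (basic auth)
cloud_client = build_confluence_client(
    connector_specific_config={
        "wiki_base": "https://example.atlassian.net/wiki/",  # trailing slash is stripped
        "is_cloud": True,
    },
    raw_credentials_json={
        "confluence_username": "user@example.com",
        "confluence_access_token": "api-token-placeholder",
    },
)

# Data Center / Server: personal access token only, since a username causes issues there
dc_client = build_confluence_client(
    connector_specific_config={"wiki_base": "https://confluence.example.internal"},
    raw_credentials_json={"confluence_access_token": "pat-placeholder"},
)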
backend/ee/danswer/external_permissions/confluence/doc_sync.py

@@ -1,19 +1,254 @@
 from typing import Any

 from atlassian import Confluence  # type:ignore
 from sqlalchemy.orm import Session

+from danswer.access.models import ExternalAccess
+from danswer.connectors.confluence.confluence_utils import (
+    build_confluence_document_id,
+)
+from danswer.connectors.confluence.rate_limit_handler import (
+    make_confluence_call_handle_rate_limit,
+)
 from danswer.db.models import ConnectorCredentialPair
+from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
 from danswer.utils.logger import setup_logger
-from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
+from ee.danswer.db.document import upsert_document_external_perms__no_commit
+from ee.danswer.external_permissions.confluence.confluence_sync_utils import (
+    build_confluence_client,
+)


 logger = setup_logger()

+_REQUEST_PAGINATION_LIMIT = 100
+
+
+def _get_space_permissions(
+    db_session: Session,
+    confluence_client: Confluence,
+    space_id: str,
+) -> ExternalAccess:
+    get_space_permissions = make_confluence_call_handle_rate_limit(
+        confluence_client.get_space_permissions
+    )
+
+    space_permissions = get_space_permissions(space_id).get("permissions", [])
+    user_emails = set()
+    # Confluence enforces that group names are unique
+    group_names = set()
+    is_externally_public = False
+    for permission in space_permissions:
+        subs = permission.get("subjects")
+        if subs:
+            # If there are subjects, then there are explicit users or groups with access
+            if email := subs.get("user", {}).get("results", [{}])[0].get("email"):
+                user_emails.add(email)
+            if group_name := subs.get("group", {}).get("results", [{}])[0].get("name"):
+                group_names.add(group_name)
+        else:
+            # If there are no subjects, then the permission is for everyone
+            if permission.get("operation", {}).get(
+                "operation"
+            ) == "read" and permission.get("anonymousAccess", False):
+                # If the permission specifies read access for anonymous users, then
+                # the space is publicly accessible
+                is_externally_public = True
+    batch_add_non_web_user_if_not_exists__no_commit(
+        db_session=db_session, emails=list(user_emails)
+    )
+    return ExternalAccess(
+        external_user_emails=user_emails,
+        external_user_group_ids=group_names,
+        is_public=is_externally_public,
+    )
+
+
+def _get_restrictions_for_page(
+    db_session: Session,
+    page: dict[str, Any],
+    space_permissions: ExternalAccess,
+) -> ExternalAccess:
+    """
+    WARNING: This function includes no pagination. So if a page is private within
+    the space and has over 200 users or over 200 groups with explicit read access,
+    this function will leave out some users or groups.
+    200 is a large amount so it is unlikely, but just be aware.
+    """
+    restrictions_json = page.get("restrictions", {})
+    read_access_dict = restrictions_json.get("read", {}).get("restrictions", {})
+
+    read_access_user_jsons = read_access_dict.get("user", {}).get("results", [])
+    read_access_group_jsons = read_access_dict.get("group", {}).get("results", [])
+
+    is_space_public = read_access_user_jsons == [] and read_access_group_jsons == []
+
+    if not is_space_public:
+        read_access_user_emails = [
+            user["email"] for user in read_access_user_jsons if user.get("email")
+        ]
+        read_access_groups = [group["name"] for group in read_access_group_jsons]
+        batch_add_non_web_user_if_not_exists__no_commit(
+            db_session=db_session, emails=list(read_access_user_emails)
+        )
+        external_access = ExternalAccess(
+            external_user_emails=set(read_access_user_emails),
+            external_user_group_ids=set(read_access_groups),
+            is_public=False,
+        )
+    else:
+        external_access = space_permissions
+
+    return external_access
+
+
+def _fetch_attachment_document_ids_for_page_paginated(
+    confluence_client: Confluence, page: dict[str, Any]
+) -> list[str]:
+    """
+    Starts by just extracting the first page of attachments from
+    the page. If all attachments are in the first page, then
+    no calls to the api are made from this function.
+    """
+    get_attachments_from_content = make_confluence_call_handle_rate_limit(
+        confluence_client.get_attachments_from_content
+    )
+
+    attachment_doc_ids = []
+    attachments_dict = page["children"]["attachment"]
+    start = 0
+
+    while True:
+        attachments_list = attachments_dict["results"]
+        attachment_doc_ids.extend(
+            [
+                build_confluence_document_id(
+                    base_url=confluence_client.url,
+                    content_url=attachment["_links"]["download"],
+                )
+                for attachment in attachments_list
+            ]
+        )
+
+        if "next" not in attachments_dict["_links"]:
+            break
+
+        start += len(attachments_list)
+        attachments_dict = get_attachments_from_content(
+            page_id=page["id"],
+            start=start,
+            limit=_REQUEST_PAGINATION_LIMIT,
+        )
+
+    return attachment_doc_ids
+
+
+def _fetch_all_pages_paginated(
+    confluence_client: Confluence,
+    space_id: str,
+) -> list[dict[str, Any]]:
+    get_all_pages_from_space = make_confluence_call_handle_rate_limit(
+        confluence_client.get_all_pages_from_space
+    )
+
+    # For each page, this fetches the page's attachments and restrictions.
+    expansion_strings = [
+        "children.attachment",
+        "restrictions.read.restrictions.user",
+        "restrictions.read.restrictions.group",
+    ]
+    expansion_string = ",".join(expansion_strings)
+
+    all_pages = []
+    start = 0
+    while True:
+        pages_dict = get_all_pages_from_space(
+            space=space_id,
+            start=start,
+            limit=_REQUEST_PAGINATION_LIMIT,
+            expand=expansion_string,
+        )
+        all_pages.extend(pages_dict)
+
+        response_size = len(pages_dict)
+        if response_size < _REQUEST_PAGINATION_LIMIT:
+            break
+        start += response_size
+
+    return all_pages
+
+
+def _fetch_all_page_restrictions_for_space(
+    db_session: Session,
+    confluence_client: Confluence,
+    space_id: str,
+    space_permissions: ExternalAccess,
+) -> dict[str, ExternalAccess]:
+    all_pages = _fetch_all_pages_paginated(
+        confluence_client=confluence_client,
+        space_id=space_id,
+    )
+
+    document_restrictions: dict[str, ExternalAccess] = {}
+    for page in all_pages:
+        """
+        This assigns the same permissions to all attachments of a page and
+        the page itself.
+        This is because the attachments are stored in the same Confluence space as the page.
+        WARNING: We create a dbDocument entry for all attachments, even though attachments
+        may not be their own standalone documents. This is likely fine as we just upsert a
+        document with just permissions.
+        """
+        attachment_document_ids = [
+            build_confluence_document_id(
+                base_url=confluence_client.url,
+                content_url=page["_links"]["webui"],
+            )
+        ]
+        attachment_document_ids.extend(
+            _fetch_attachment_document_ids_for_page_paginated(
+                confluence_client=confluence_client, page=page
+            )
+        )
+        page_permissions = _get_restrictions_for_page(
+            db_session=db_session,
+            page=page,
+            space_permissions=space_permissions,
+        )
+        for attachment_document_id in attachment_document_ids:
+            document_restrictions[attachment_document_id] = page_permissions
+
+    return document_restrictions
+
+
 def confluence_doc_sync(
     db_session: Session,
     cc_pair: ConnectorCredentialPair,
-    docs_with_additional_info: list[DocsWithAdditionalInfo],
-    sync_details: dict[str, Any],
 ) -> None:
-    logger.debug("Not yet implemented ACL sync for confluence, no-op")
+    """
+    Adds the external permissions to the documents in postgres.
+    If the document doesn't already exist in postgres, we create
+    it in postgres so that when it gets created later, the permissions are
+    already populated
+    """
+    confluence_client = build_confluence_client(
+        cc_pair.connector.connector_specific_config, cc_pair.credential.credential_json
+    )
+    space_permissions = _get_space_permissions(
+        db_session=db_session,
+        confluence_client=confluence_client,
+        space_id=cc_pair.connector.connector_specific_config["space"],
+    )
+    fresh_doc_permissions = _fetch_all_page_restrictions_for_space(
+        db_session=db_session,
+        confluence_client=confluence_client,
+        space_id=cc_pair.connector.connector_specific_config["space"],
+        space_permissions=space_permissions,
+    )
+    for doc_id, ext_access in fresh_doc_permissions.items():
+        upsert_document_external_perms__no_commit(
+            db_session=db_session,
+            doc_id=doc_id,
+            external_access=ext_access,
+            source_type=cc_pair.connector.source,
+        )
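For orientation, _get_space_permissions expects the space-permissions payload to look roughly like the sketch below. This is inferred from the accessors in the parsing loop above, not copied from Confluence API docs, so treat the exact nesting as an assumption:

space_permissions_response = {
    "permissions": [
        {   # explicit user grant -> email collected
            "subjects": {"user": {"results": [{"email": "a@example.com"}]}},
            "operation": {"operation": "read"},
        },
        {   # explicit group grant -> group name collected
            "subjects": {"group": {"results": [{"name": "engineering"}]}},
            "operation": {"operation": "read"},
        },
        {   # no subjects + anonymous read -> space treated as public
            "operation": {"operation": "read"},
            "anonymousAccess": True,
        },
    ]
}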
backend/ee/danswer/external_permissions/confluence/group_sync.py

@@ -1,19 +1,107 @@
-from typing import Any
+from collections.abc import Iterator

 from atlassian import Confluence  # type:ignore
+from requests import HTTPError
 from sqlalchemy.orm import Session

+from danswer.connectors.confluence.rate_limit_handler import (
+    make_confluence_call_handle_rate_limit,
+)
 from danswer.db.models import ConnectorCredentialPair
+from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
 from danswer.utils.logger import setup_logger
-from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
+from ee.danswer.db.external_perm import ExternalUserGroup
+from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
+from ee.danswer.external_permissions.confluence.confluence_sync_utils import (
+    build_confluence_client,
+)


 logger = setup_logger()

+_PAGE_SIZE = 100
+
+
+def _get_confluence_group_names_paginated(
+    confluence_client: Confluence,
+) -> Iterator[str]:
+    get_all_groups = make_confluence_call_handle_rate_limit(
+        confluence_client.get_all_groups
+    )
+
+    start = 0
+    while True:
+        try:
+            groups = get_all_groups(start=start, limit=_PAGE_SIZE)
+        except HTTPError as e:
+            if e.response.status_code in (403, 404):
+                return
+            raise e
+
+        for group in groups:
+            if group_name := group.get("name"):
+                yield group_name
+
+        if len(groups) < _PAGE_SIZE:
+            break
+        start += _PAGE_SIZE
+
+
+def _get_group_members_email_paginated(
+    confluence_client: Confluence,
+    group_name: str,
+) -> list[str]:
+    get_group_members = make_confluence_call_handle_rate_limit(
+        confluence_client.get_group_members
+    )
+    group_member_emails: list[str] = []
+    start = 0
+    while True:
+        try:
+            members = get_group_members(
+                group_name=group_name, start=start, limit=_PAGE_SIZE
+            )
+        except HTTPError as e:
+            if e.response.status_code == 403 or e.response.status_code == 404:
+                return group_member_emails
+            raise e
+
+        group_member_emails.extend(
+            [member.get("email") for member in members if member.get("email")]
+        )
+        if len(members) < _PAGE_SIZE:
+            break
+        start += _PAGE_SIZE
+    return group_member_emails
+
+
 def confluence_group_sync(
     db_session: Session,
     cc_pair: ConnectorCredentialPair,
-    docs_with_additional_info: list[DocsWithAdditionalInfo],
-    sync_details: dict[str, Any],
 ) -> None:
-    logger.debug("Not yet implemented group sync for confluence, no-op")
+    confluence_client = build_confluence_client(
+        cc_pair.connector.connector_specific_config, cc_pair.credential.credential_json
+    )
+
+    danswer_groups: list[ExternalUserGroup] = []
+    # Confluence enforces that group names are unique
+    for group_name in _get_confluence_group_names_paginated(confluence_client):
+        group_member_emails = _get_group_members_email_paginated(
+            confluence_client, group_name
+        )
+        group_members = batch_add_non_web_user_if_not_exists__no_commit(
+            db_session=db_session, emails=group_member_emails
+        )
+        if group_members:
+            danswer_groups.append(
+                ExternalUserGroup(
+                    id=group_name, user_ids=[user.id for user in group_members]
+                )
+            )
+
+    replace_user__ext_group_for_cc_pair__no_commit(
+        db_session=db_session,
+        cc_pair_id=cc_pair.id,
+        group_defs=danswer_groups,
+        source=cc_pair.connector.source,
+    )
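Similarly, the group helpers above only rely on a couple of fields per record; the shapes below are inferred from those accessors (illustrative, not verbatim API responses):

# confluence_client.get_all_groups(start=..., limit=...) pages look like:
groups_page = [{"name": "engineering"}, {"name": "design"}]

# confluence_client.get_group_members(group_name=..., ...) pages look like:
members_page = [
    {"email": "a@example.com"},       # collected
    {"displayName": "Service Bot"},   # no email -> skipped by the sync
]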
backend/ee/danswer/external_permissions/google_drive/doc_sync.py

@@ -1,4 +1,6 @@
 from collections.abc import Iterator
+from datetime import datetime
+from datetime import timezone
 from typing import Any
 from typing import cast

@@ -8,15 +10,17 @@ from sqlalchemy.orm import Session

 from danswer.access.models import ExternalAccess
 from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
+from danswer.connectors.factory import instantiate_connector
 from danswer.connectors.google_drive.connector_auth import (
     get_google_drive_creds,
 )
 from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
+from danswer.connectors.interfaces import PollConnector
+from danswer.connectors.models import InputType
 from danswer.db.models import ConnectorCredentialPair
 from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
 from danswer.utils.logger import setup_logger
 from ee.danswer.db.document import upsert_document_external_perms__no_commit
-from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo

 # Google Drive APIs are quite flakey and may 500 for an
 # extended period of time. Trying to combat here by adding a very

@@ -27,6 +31,42 @@ add_retries = retry_builder(tries=5, delay=5, max_delay=30)

 logger = setup_logger()


+def _get_docs_with_additional_info(
+    db_session: Session,
+    cc_pair: ConnectorCredentialPair,
+) -> dict[str, Any]:
+    # Get all document ids that need their permissions updated
+    runnable_connector = instantiate_connector(
+        db_session=db_session,
+        source=cc_pair.connector.source,
+        input_type=InputType.POLL,
+        connector_specific_config=cc_pair.connector.connector_specific_config,
+        credential=cc_pair.credential,
+    )
+
+    assert isinstance(runnable_connector, PollConnector)
+
+    current_time = datetime.now(timezone.utc)
+    start_time = (
+        cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp()
+        if cc_pair.last_time_perm_sync
+        else 0.0
+    )
+    cc_pair.last_time_perm_sync = current_time
+
+    doc_batch_generator = runnable_connector.poll_source(
+        start=start_time, end=current_time.timestamp()
+    )
+
+    docs_with_additional_info = {
+        doc.id: doc.additional_info
+        for doc_batch in doc_batch_generator
+        for doc in doc_batch
+    }
+
+    return docs_with_additional_info
+
+
 def _fetch_permissions_paginated(
     drive_service: Any, drive_file_id: str
 ) -> Iterator[dict[str, Any]]:

@@ -122,8 +162,6 @@ def _fetch_google_permissions_for_document_id(
 def gdrive_doc_sync(
     db_session: Session,
     cc_pair: ConnectorCredentialPair,
-    docs_with_additional_info: list[DocsWithAdditionalInfo],
-    sync_details: dict[str, Any],
 ) -> None:
     """
     Adds the external permissions to the documents in postgres

@@ -131,10 +169,24 @@ def gdrive_doc_sync(
     it in postgres so that when it gets created later, the permissions are
     already populated
     """
-    for doc in docs_with_additional_info:
+    sync_details = cc_pair.auto_sync_options
+    if sync_details is None:
+        logger.error("Sync details not found for Google Drive")
+        raise ValueError("Sync details not found for Google Drive")
+
+    # Here we run the connector to grab all the ids
+    # this may grab ids before they are indexed but that is fine because
+    # we create a document in postgres to hold the permissions info
+    # until the indexing job has a chance to run
+    docs_with_additional_info = _get_docs_with_additional_info(
+        db_session=db_session,
+        cc_pair=cc_pair,
+    )
+
+    for doc_id, doc_additional_info in docs_with_additional_info.items():
         ext_access = _fetch_google_permissions_for_document_id(
             db_session=db_session,
-            drive_file_id=doc.additional_info,
+            drive_file_id=doc_additional_info,
             raw_credentials_json=cc_pair.credential.credential_json,
             company_google_domains=[
                 cast(dict[str, str], sync_details)["company_domain"]

@@ -142,7 +194,7 @@ def gdrive_doc_sync(
         )
         upsert_document_external_perms__no_commit(
             db_session=db_session,
-            doc_id=doc.id,
+            doc_id=doc_id,
             external_access=ext_access,
             source_type=cc_pair.connector.source,
         )
backend/ee/danswer/external_permissions/google_drive/group_sync.py

@@ -17,7 +17,6 @@ from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
 from danswer.utils.logger import setup_logger
 from ee.danswer.db.external_perm import ExternalUserGroup
 from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
-from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo

 logger = setup_logger()

@@ -105,9 +104,12 @@ def _fetch_group_members_paginated(
 def gdrive_group_sync(
     db_session: Session,
     cc_pair: ConnectorCredentialPair,
-    docs_with_additional_info: list[DocsWithAdditionalInfo],
-    sync_details: dict[str, Any],
 ) -> None:
+    sync_details = cc_pair.auto_sync_options
+    if sync_details is None:
+        logger.error("Sync details not found for Google Drive")
+        raise ValueError("Sync details not found for Google Drive")
+
     google_drive_creds, _ = get_google_drive_creds(
         cc_pair.credential.credential_json,
         scopes=FETCH_GROUPS_SCOPES,
backend/ee/danswer/external_permissions/permission_sync.py

@@ -4,32 +4,68 @@ from datetime import timezone

 from sqlalchemy.orm import Session

 from danswer.access.access import get_access_for_documents
+from danswer.configs.constants import DocumentSource
 from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
-from danswer.db.search_settings import get_current_search_settings
-from danswer.document_index.factory import get_default_document_index
+from danswer.db.document import get_document_ids_for_connector_credential_pair
+from danswer.document_index.factory import get_current_primary_default_document_index
 from danswer.document_index.interfaces import UpdateRequest
 from danswer.utils.logger import setup_logger
 from ee.danswer.external_permissions.permission_sync_function_map import (
     DOC_PERMISSIONS_FUNC_MAP,
 )
-from ee.danswer.external_permissions.permission_sync_function_map import (
-    FULL_FETCH_PERIOD_IN_SECONDS,
-)
 from ee.danswer.external_permissions.permission_sync_function_map import (
     GROUP_PERMISSIONS_FUNC_MAP,
 )
-from ee.danswer.external_permissions.permission_sync_utils import (
-    get_docs_with_additional_info,
-)

 logger = setup_logger()


-def run_permission_sync_entrypoint(
+# None means that the connector runs every time
+_RESTRICTED_FETCH_PERIOD: dict[DocumentSource, int | None] = {
+    # Polling is supported
+    DocumentSource.GOOGLE_DRIVE: None,
+    # Polling is not supported so we fetch all doc permissions every 5 minutes
+    DocumentSource.CONFLUENCE: 5 * 60,
+}
+
+
+def run_external_group_permission_sync(
     db_session: Session,
     cc_pair_id: int,
 ) -> None:
     cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
     if cc_pair is None:
         raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")

+    source_type = cc_pair.connector.source
+    group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
+
+    if group_sync_func is None:
+        # Not all sync connectors support group permissions so this is fine
+        return
+
+    try:
+        # This function updates:
+        # - the user_email <-> external_user_group_id mapping
+        # in postgres without committing
+        logger.debug(f"Syncing groups for {source_type}")
+        if group_sync_func is not None:
+            group_sync_func(
+                db_session,
+                cc_pair,
+            )
+
+        # update postgres
+        db_session.commit()
+    except Exception as e:
+        logger.error(f"Error updating document index: {e}")
+        db_session.rollback()
+
+
+def run_external_doc_permission_sync(
+    db_session: Session,
+    cc_pair_id: int,
+) -> None:
+    # TODO: separate out group and doc sync
+    cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
+    if cc_pair is None:
+        raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")

@@ -37,20 +73,16 @@ def run_permission_sync_entrypoint(
     source_type = cc_pair.connector.source

     doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
-    group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)

     if doc_sync_func is None:
         raise ValueError(
             f"No permission sync function found for source type: {source_type}"
         )

-    sync_details = cc_pair.auto_sync_options
-    if sync_details is None:
-        raise ValueError(f"No auto sync options found for source type: {source_type}")
-
-    # If the source type is not polling, we only fetch the permissions every
-    # _FULL_FETCH_PERIOD_IN_SECONDS seconds
-    full_fetch_period = FULL_FETCH_PERIOD_IN_SECONDS[source_type]
+    # If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
+    # If RESTRICTED_FETCH_PERIOD is not None, we only run the sync if the
+    # last sync was more than RESTRICTED_FETCH_PERIOD seconds ago.
+    full_fetch_period = _RESTRICTED_FETCH_PERIOD[cc_pair.connector.source]
     if full_fetch_period is not None:
         last_sync = cc_pair.last_time_perm_sync
         if (

@@ -62,65 +94,45 @@ def run_permission_sync_entrypoint(
     ):
         return

-    # Here we run the connector to grab all the ids
-    # this may grab ids before they are indexed but that is fine because
-    # we create a document in postgres to hold the permissions info
-    # until the indexing job has a chance to run
-    docs_with_additional_info = get_docs_with_additional_info(
-        db_session=db_session,
-        cc_pair=cc_pair,
-    )
-
-    # This function updates:
-    # - the user_email <-> external_user_group_id mapping
-    # in postgres without committing
-    logger.debug(f"Syncing groups for {source_type}")
-    if group_sync_func is not None:
-        group_sync_func(
+    try:
+        # This function updates:
+        # - the user_email <-> document mapping
+        # - the external_user_group_id <-> document mapping
+        # in postgres without committing
+        logger.debug(f"Syncing docs for {source_type}")
+        doc_sync_func(
             db_session,
             cc_pair,
-            docs_with_additional_info,
-            sync_details,
         )

-    # This function updates:
-    # - the user_email <-> document mapping
-    # - the external_user_group_id <-> document mapping
-    # in postgres without committing
-    logger.debug(f"Syncing docs for {source_type}")
-    doc_sync_func(
-        db_session,
-        cc_pair,
-        docs_with_additional_info,
-        sync_details,
-    )
+        # Get the document ids for the cc pair
+        document_ids_for_cc_pair = get_document_ids_for_connector_credential_pair(
+            db_session=db_session,
+            connector_id=cc_pair.connector_id,
+            credential_id=cc_pair.credential_id,
+        )

-    # This function fetches the updated access for the documents
-    # and returns a dictionary of document_ids and access
-    # This is the access we want to update vespa with
-    docs_access = get_access_for_documents(
-        document_ids=[doc.id for doc in docs_with_additional_info],
-        db_session=db_session,
-    )
+        # This function fetches the updated access for the documents
+        # and returns a dictionary of document_ids and access
+        # This is the access we want to update vespa with
+        docs_access = get_access_for_documents(
+            document_ids=document_ids_for_cc_pair,
+            db_session=db_session,
+        )

-    # Then we build the update requests to update vespa
-    update_reqs = [
-        UpdateRequest(document_ids=[doc_id], access=doc_access)
-        for doc_id, doc_access in docs_access.items()
-    ]
+        # Then we build the update requests to update vespa
+        update_reqs = [
+            UpdateRequest(document_ids=[doc_id], access=doc_access)
+            for doc_id, doc_access in docs_access.items()
+        ]

-    # Don't bother sync-ing secondary, it will be sync-ed after switch anyway
-    search_settings = get_current_search_settings(db_session)
-    document_index = get_default_document_index(
-        primary_index_name=search_settings.index_name,
-        secondary_index_name=None,
-    )
+        # Don't bother sync-ing secondary, it will be sync-ed after switch anyway
+        document_index = get_current_primary_default_document_index(db_session)

-    try:
         # update vespa
         document_index.update(update_reqs)
         # update postgres
         db_session.commit()
     except Exception as e:
-        logger.error(f"Error updating document index: {e}")
+        logger.error(f"Error Syncing Permissions: {e}")
         db_session.rollback()
backend/ee/danswer/external_permissions/permission_sync_function_map.py

@@ -1,5 +1,4 @@
 from collections.abc import Callable
-from typing import Any

 from sqlalchemy.orm import Session

@@ -9,15 +8,14 @@ from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_sync
 from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync
 from ee.danswer.external_permissions.google_drive.doc_sync import gdrive_doc_sync
 from ee.danswer.external_permissions.google_drive.group_sync import gdrive_group_sync
-from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo

-GroupSyncFuncType = Callable[
-    [Session, ConnectorCredentialPair, list[DocsWithAdditionalInfo], dict[str, Any]],
-    None,
-]
-
-DocSyncFuncType = Callable[
-    [Session, ConnectorCredentialPair, list[DocsWithAdditionalInfo], dict[str, Any]],
+# Defining the input/output types for the sync functions
+SyncFuncType = Callable[
+    [
+        Session,
+        ConnectorCredentialPair,
+    ],
     None,
 ]

@@ -26,7 +24,7 @@ DocSyncFuncType = Callable[
 # - the external_user_group_id <-> document mapping
 # in postgres without committing
 # THIS ONE IS NECESSARY FOR AUTO SYNC TO WORK
-DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
+DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
     DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync,
     DocumentSource.CONFLUENCE: confluence_doc_sync,
 }

@@ -35,20 +33,11 @@ DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
 # - the user_email <-> external_user_group_id mapping
 # in postgres without committing
 # THIS ONE IS OPTIONAL ON AN APP BY APP BASIS
-GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, GroupSyncFuncType] = {
+GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
     DocumentSource.GOOGLE_DRIVE: gdrive_group_sync,
     DocumentSource.CONFLUENCE: confluence_group_sync,
 }

-
-# None means that the connector supports polling from last_time_perm_sync to now
-FULL_FETCH_PERIOD_IN_SECONDS: dict[DocumentSource, int | None] = {
-    # Polling is supported
-    DocumentSource.GOOGLE_DRIVE: None,
-    # Polling is not supported so we fetch all doc permissions every 10 minutes
-    DocumentSource.CONFLUENCE: 10 * 60,
-}
-

 def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
     return source_type in DOC_PERMISSIONS_FUNC_MAP
@ -1,56 +0,0 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class DocsWithAdditionalInfo(BaseModel):
|
||||
id: str
|
||||
additional_info: Any
|
||||
|
||||
|
||||
def get_docs_with_additional_info(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocsWithAdditionalInfo]:
|
||||
# Get all document ids that need their permissions updated
|
||||
runnable_connector = instantiate_connector(
|
||||
db_session=db_session,
|
||||
source=cc_pair.connector.source,
|
||||
input_type=InputType.POLL,
|
||||
connector_specific_config=cc_pair.connector.connector_specific_config,
|
||||
credential=cc_pair.credential,
|
||||
)
|
||||
|
||||
assert isinstance(runnable_connector, PollConnector)
|
||||
|
||||
current_time = datetime.now(timezone.utc)
|
||||
start_time = (
|
||||
cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp()
|
||||
if cc_pair.last_time_perm_sync
|
||||
else 0
|
||||
)
|
||||
cc_pair.last_time_perm_sync = current_time
|
||||
|
||||
doc_batch_generator = runnable_connector.poll_source(
|
||||
start=start_time, end=current_time.timestamp()
|
||||
)
|
||||
|
||||
docs_with_additional_info = [
|
||||
DocsWithAdditionalInfo(id=doc.id, additional_info=doc.additional_info)
|
||||
for doc_batch in doc_batch_generator
|
||||
for doc in doc_batch
|
||||
]
|
||||
logger.debug(f"Docs with additional info: {len(docs_with_additional_info)}")
|
||||
|
||||
return docs_with_additional_info
|
AccessTypeForm.tsx

@@ -5,12 +5,9 @@ import {
   ConfigurableSources,
   validAutoSyncSources,
 } from "@/lib/types";
 import { Text, Title } from "@tremor/react";
 import { useUser } from "@/components/user/UserProvider";
 import { useField } from "formik";
 import { AutoSyncOptions } from "./AutoSyncOptions";
-import { useContext } from "react";
-import { SettingsContext } from "@/components/settings/SettingsProvider";
 import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";

 function isValidAutoSyncSource(

@@ -28,7 +25,6 @@ export function AccessTypeForm({
     useField<AccessType>("access_type");

   const isPaidEnterpriseEnabled = usePaidEnterpriseFeaturesEnabled();
-  const settings = useContext(SettingsContext);
   const isAutoSyncSupported = isValidAutoSyncSource(connector);
   const { isLoadingUser, isAdmin } = useUser();

@@ -85,7 +81,7 @@ export function AccessTypeForm({
       />

       {access_type.value === "sync" && isAutoSyncSupported && (
-        <div className="mt-6">
+        <div>
           <AutoSyncOptions
             connectorType={connector as ValidAutoSyncSources}
           />
AutoSyncOptions.tsx

@@ -1,5 +1,4 @@
 import { TextFormField } from "@/components/admin/connectors/Field";
 import { useFormikContext } from "formik";
 import { ValidAutoSyncSources } from "@/lib/types";
-import { Divider } from "@tremor/react";
 import { autoSyncConfigBySource } from "@/lib/connectors/AutoSyncOptionFields";

@@ -9,22 +8,24 @@ export function AutoSyncOptions({
 }: {
   connectorType: ValidAutoSyncSources;
 }) {
+  const autoSyncConfig = autoSyncConfigBySource[connectorType];
+
+  if (Object.keys(autoSyncConfig).length === 0) {
+    return null;
+  }
+
   return (
     <div>
-      <Divider />
-      <>
-        {Object.entries(autoSyncConfigBySource[connectorType]).map(
-          ([key, config]) => (
-            <div key={key} className="mb-4">
-              <TextFormField
-                name={`auto_sync_options.${key}`}
-                label={config.label}
-                subtext={config.subtext}
-              />
-            </div>
-          )
-        )}
-      </>
+      {Object.entries(autoSyncConfig).map(([key, config]) => (
+        <div key={key} className="mb-4">
+          <TextFormField
+            name={`auto_sync_options.${key}`}
+            label={config.label}
+            subtext={config.subtext}
+          />
+        </div>
+      ))}
     </div>
   );
 }
web/src/lib/connectors/AutoSyncOptionFields.ts

@@ -11,6 +11,7 @@ export const autoSyncConfigBySource: Record<
     }
   >
 > = {
+  confluence: {},
   google_drive: {
     customer_id: {
       label: "Google Workspace Customer ID",
web/src/lib/types.ts

@@ -264,5 +264,5 @@ export type ConfigurableSources = Exclude<
 >;

 // The sources that have auto-sync support on the backend
-export const validAutoSyncSources = ["google_drive"] as const;
+export const validAutoSyncSources = ["confluence", "google_drive"] as const;
 export type ValidAutoSyncSources = (typeof validAutoSyncSources)[number];