Confluence polish (#2874)

This commit is contained in:
hagen-danswer
2024-10-22 13:41:47 -07:00
committed by GitHub
parent e031576c87
commit 914da2e4cb
6 changed files with 155 additions and 295 deletions

View File

@@ -1,6 +1,7 @@
from datetime import datetime from datetime import datetime
from datetime import timezone from datetime import timezone
from typing import Any from typing import Any
from urllib.parse import quote
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
@@ -8,6 +9,7 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource from danswer.configs.constants import DocumentSource
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
from danswer.connectors.confluence.utils import attachment_to_content from danswer.connectors.confluence.utils import attachment_to_content
from danswer.connectors.confluence.utils import build_confluence_client
from danswer.connectors.confluence.utils import build_confluence_document_id from danswer.connectors.confluence.utils import build_confluence_document_id
from danswer.connectors.confluence.utils import datetime_from_string from danswer.connectors.confluence.utils import datetime_from_string
from danswer.connectors.confluence.utils import extract_text_from_confluence_html from danswer.connectors.confluence.utils import extract_text_from_confluence_html
@@ -74,20 +76,21 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.wiki_base = wiki_base.rstrip("/") self.wiki_base = wiki_base.rstrip("/")
# if nothing is provided, we will fetch all pages # if nothing is provided, we will fetch all pages
self.cql_page_query = "type=page" cql_page_query = "type=page"
if cql_query: if cql_query:
# if a cql_query is provided, we will use it to fetch the pages # if a cql_query is provided, we will use it to fetch the pages
self.cql_page_query = cql_query cql_page_query = cql_query
elif space: elif space:
# if no cql_query is provided, we will use the space to fetch the pages # if no cql_query is provided, we will use the space to fetch the pages
self.cql_page_query += f" and space='{space}'" cql_page_query += f" and space='{quote(space)}'"
elif page_id: elif page_id:
if index_recursively: if index_recursively:
self.cql_page_query += f" and ancestor='{page_id}'" cql_page_query += f" and ancestor='{page_id}'"
else: else:
# if neither a space nor a cql_query is provided, we will use the page_id to fetch the page # if neither a space nor a cql_query is provided, we will use the page_id to fetch the page
self.cql_page_query += f" and id='{page_id}'" cql_page_query += f" and id='{page_id}'"
self.cql_page_query = cql_page_query
self.cql_label_filter = "" self.cql_label_filter = ""
self.cql_time_filter = "" self.cql_time_filter = ""
if labels_to_skip: if labels_to_skip:
@@ -96,19 +99,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.cql_label_filter = f"&label not in ({comma_separated_labels})" self.cql_label_filter = f"&label not in ({comma_separated_labels})"
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
username = credentials["confluence_username"]
access_token = credentials["confluence_access_token"]
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py # see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
# for a list of other hidden constructor args # for a list of other hidden constructor args
self.confluence_client = OnyxConfluence( self.confluence_client = build_confluence_client(
url=self.wiki_base, credentials_json=credentials,
username=username if self.is_cloud else None, is_cloud=self.is_cloud,
password=access_token if self.is_cloud else None, wiki_base=self.wiki_base,
token=access_token if not self.is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=60,
max_backoff_seconds=60,
) )
return None return None
@@ -202,12 +198,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
# Fetch pages as Documents # Fetch pages as Documents
for pages in self.confluence_client.paginated_cql_page_retrieval( for page_batch in self.confluence_client.paginated_cql_page_retrieval(
cql=page_query, cql=page_query,
expand=",".join(_PAGE_EXPANSION_FIELDS), expand=",".join(_PAGE_EXPANSION_FIELDS),
limit=self.batch_size, limit=self.batch_size,
): ):
for page in pages: for page in page_batch:
confluence_page_ids.append(page["id"]) confluence_page_ids.append(page["id"])
doc = self._convert_object_to_document(page) doc = self._convert_object_to_document(page)
if doc is not None: if doc is not None:

View File

@@ -5,6 +5,7 @@ from collections.abc import Iterator
from typing import Any from typing import Any
from typing import cast from typing import cast
from typing import TypeVar from typing import TypeVar
from urllib.parse import quote
from atlassian import Confluence # type:ignore from atlassian import Confluence # type:ignore
from requests import HTTPError from requests import HTTPError
@@ -111,7 +112,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
return cast(F, wrapped_call) return cast(F, wrapped_call)
_PAGINATION_LIMIT = 100 _DEFAULT_PAGINATION_LIMIT = 100
class OnyxConfluence(Confluence): class OnyxConfluence(Confluence):
@@ -138,35 +139,62 @@ class OnyxConfluence(Confluence):
handle_confluence_rate_limit(getattr(self, attr_name)), handle_confluence_rate_limit(getattr(self, attr_name)),
) )
def _paginate_url(
self, url_suffix: str, limit: int | None = None
) -> Iterator[list[dict[str, Any]]]:
"""
This will paginate through the top level query.
"""
if not limit:
limit = _DEFAULT_PAGINATION_LIMIT
connection_char = "&" if "?" in url_suffix else "?"
url_suffix += f"{connection_char}limit={limit}"
while url_suffix:
try:
next_response = self.get(url_suffix)
except Exception as e:
logger.exception("Error in danswer_cql: \n")
raise e
yield next_response.get("results", [])
url_suffix = next_response.get("_links", {}).get("next")
def paginated_groups_retrieval(
self,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
return self._paginate_url("rest/api/group", limit)
def paginated_group_members_retrieval(
self,
group_name: str,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
group_name = quote(group_name)
return self._paginate_url(f"rest/api/group/{group_name}/member", limit)
def paginated_cql_user_retrieval(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
expand_string = f"&expand={expand}" if expand else ""
return self._paginate_url(
f"rest/api/search/user?cql={cql}{expand_string}", limit
)
def paginated_cql_page_retrieval( def paginated_cql_page_retrieval(
self, self,
cql: str, cql: str,
expand: str | None = None, expand: str | None = None,
limit: int | None = None, limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]: ) -> Iterator[list[dict[str, Any]]]:
""" expand_string = f"&expand={expand}" if expand else ""
This will paginate through the top level query. return self._paginate_url(
""" f"rest/api/content/search?cql={cql}{expand_string}", limit
url_suffix = f"rest/api/content/search?cql={cql}" )
if expand:
url_suffix += f"&expand={expand}"
if not limit:
limit = _PAGINATION_LIMIT
url_suffix += f"&limit={limit}"
while True:
try:
response = self.get(url_suffix)
results = response["results"]
except Exception as e:
logger.exception("Error in danswer_cql: \n")
raise e
yield results
url_suffix = response.get("_links", {}).get("next")
if not url_suffix:
break
def cql_paginate_all_expansions( def cql_paginate_all_expansions(
self, self,
@@ -185,10 +213,7 @@ class OnyxConfluence(Confluence):
if isinstance(data, dict): if isinstance(data, dict):
next_url = data.get("_links", {}).get("next") next_url = data.get("_links", {}).get("next")
if next_url and "results" in data: if next_url and "results" in data:
while next_url: data["results"].extend(self._paginate_url(next_url))
next_response = self.get(next_url)
data["results"].extend(next_response.get("results", []))
next_url = next_response.get("_links", {}).get("next")
for value in data.values(): for value in data.values():
_traverse_and_update(value) _traverse_and_update(value)
@@ -199,113 +224,3 @@ class OnyxConfluence(Confluence):
for results in self.paginated_cql_page_retrieval(cql, expand, limit): for results in self.paginated_cql_page_retrieval(cql, expand, limit):
_traverse_and_update(results) _traverse_and_update(results)
yield results yield results
# commenting out while we try using confluence's rate limiter instead
# # https://developer.atlassian.com/cloud/confluence/rate-limiting/
# def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
# def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
# max_retries = 5
# starting_delay = 5
# backoff = 2
# # max_delay is used when the server doesn't hand back "Retry-After"
# # and we have to decide the retry delay ourselves
# max_delay = 30 # Atlassian uses max_delay = 30 in their examples
# # max_retry_after is used when we do get a "Retry-After" header
# max_retry_after = 300 # should we really cap the maximum retry delay?
# NEXT_RETRY_KEY = BaseConnector.REDIS_KEY_PREFIX + "confluence_next_retry"
# # for testing purposes, rate limiting is written to fall back to a simpler
# # rate limiting approach when redis is not available
# r = get_redis_client()
# for attempt in range(max_retries):
# try:
# # if multiple connectors are waiting for the next attempt, there could be an issue
# # where many connectors are "released" onto the server at the same time.
# # That's not ideal ... but coming up with a mechanism for queueing
# # all of these connectors is a bigger problem that we want to take on
# # right now
# try:
# next_attempt = r.get(NEXT_RETRY_KEY)
# if next_attempt is None:
# next_attempt = 0
# else:
# next_attempt = int(cast(int, next_attempt))
# # TODO: all connectors need to be interruptible moving forward
# while time.monotonic() < next_attempt:
# time.sleep(1)
# except ConnectionError:
# pass
# return confluence_call(*args, **kwargs)
# except HTTPError as e:
# # Check if the response or headers are None to avoid potential AttributeError
# if e.response is None or e.response.headers is None:
# logger.warning("HTTPError with `None` as response or as headers")
# raise e
# retry_after_header = e.response.headers.get("Retry-After")
# if (
# e.response.status_code == 429
# or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
# ):
# retry_after = None
# if retry_after_header is not None:
# try:
# retry_after = int(retry_after_header)
# except ValueError:
# pass
# if retry_after is not None:
# if retry_after > max_retry_after:
# logger.warning(
# f"Clamping retry_after from {retry_after} to {max_delay} seconds..."
# )
# retry_after = max_delay
# logger.warning(
# f"Rate limit hit. Retrying after {retry_after} seconds..."
# )
# try:
# r.set(
# NEXT_RETRY_KEY,
# math.ceil(time.monotonic() + retry_after),
# )
# except ConnectionError:
# pass
# else:
# logger.warning(
# "Rate limit hit. Retrying with exponential backoff..."
# )
# delay = min(starting_delay * (backoff**attempt), max_delay)
# delay_until = math.ceil(time.monotonic() + delay)
# try:
# r.set(NEXT_RETRY_KEY, delay_until)
# except ConnectionError:
# while time.monotonic() < delay_until:
# time.sleep(1)
# else:
# # re-raise, let caller handle
# raise
# except AttributeError as e:
# # Some error within the Confluence library, unclear why it fails.
# # Users reported it to be intermittent, so just retry
# logger.warning(f"Confluence Internal Error, retrying... {e}")
# delay = min(starting_delay * (backoff**attempt), max_delay)
# delay_until = math.ceil(time.monotonic() + delay)
# try:
# r.set(NEXT_RETRY_KEY, delay_until)
# except ConnectionError:
# while time.monotonic() < delay_until:
# time.sleep(1)
# if attempt == max_retries - 1:
# raise e
# return cast(F, wrapped_call)

View File

@@ -18,7 +18,25 @@ from danswer.utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
_USER_NOT_FOUND = "Unknown User"
_USER_EMAIL_CACHE: dict[str, str | None] = {}
def get_user_email_from_username__server(
confluence_client: OnyxConfluence, user_name: str
) -> str | None:
global _USER_EMAIL_CACHE
if _USER_EMAIL_CACHE.get(user_name) is None:
try:
response = confluence_client.get_mobile_parameters(user_name)
email = response.get("email")
except Exception:
email = None
_USER_EMAIL_CACHE[user_name] = email
return _USER_EMAIL_CACHE[user_name]
_USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str] = {} _USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str] = {}
@@ -32,19 +50,22 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
Returns: Returns:
str: The User Display Name. 'Unknown User' if the user is deactivated or not found str: The User Display Name. 'Unknown User' if the user is deactivated or not found
""" """
# Cache hit global _USER_ID_TO_DISPLAY_NAME_CACHE
if user_id in _USER_ID_TO_DISPLAY_NAME_CACHE: if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
return _USER_ID_TO_DISPLAY_NAME_CACHE[user_id] try:
result = confluence_client.get_user_details_by_userkey(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
try: if not found_display_name:
result = confluence_client.get_user_details_by_accountid(user_id) try:
if found_display_name := result.get("displayName"): result = confluence_client.get_user_details_by_accountid(user_id)
_USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name found_display_name = result.get("displayName")
except Exception: except Exception:
# may need to just not log this error but will leave here for now found_display_name = None
logger.exception(
f"Unable to get the User Display Name with the id: '{user_id}'" _USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name
)
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id, _USER_NOT_FOUND) return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id, _USER_NOT_FOUND)
@@ -174,3 +195,20 @@ def datetime_from_string(datetime_string: str) -> datetime:
datetime_object = datetime_object.astimezone(timezone.utc) datetime_object = datetime_object.astimezone(timezone.utc)
return datetime_object return datetime_object
def build_confluence_client(
    credentials_json: dict[str, Any], is_cloud: bool, wiki_base: str
) -> OnyxConfluence:
    """Construct an OnyxConfluence client from stored connector credentials.

    Cloud instances authenticate with username + API token (basic auth);
    Server / Data Center instances use a personal access token instead.
    """
    access_token = credentials_json["confluence_access_token"]
    return OnyxConfluence(
        api_version="cloud" if is_cloud else "latest",
        # Remove trailing slash from wiki_base if present
        url=wiki_base.rstrip("/"),
        # passing in username causes issues for Confluence data center
        username=credentials_json["confluence_username"] if is_cloud else None,
        password=access_token if is_cloud else None,
        token=access_token if not is_cloud else None,
        backoff_and_retry=True,
        max_backoff_retries=60,
        max_backoff_seconds=60,
    )

View File

@@ -8,15 +8,13 @@ from sqlalchemy.orm import Session
from danswer.access.models import ExternalAccess from danswer.access.models import ExternalAccess
from danswer.connectors.confluence.connector import ConfluenceConnector from danswer.connectors.confluence.connector import ConfluenceConnector
from danswer.connectors.confluence.connector import OnyxConfluence from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
from danswer.connectors.confluence.utils import get_user_email_from_username__server
from danswer.connectors.models import SlimDocument from danswer.connectors.models import SlimDocument
from danswer.db.models import ConnectorCredentialPair from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from ee.danswer.db.document import upsert_document_external_perms__no_commit from ee.danswer.db.document import upsert_document_external_perms__no_commit
from ee.danswer.external_permissions.confluence.sync_utils import (
get_user_email_from_username__server,
)
logger = setup_logger() logger = setup_logger()
@@ -244,11 +242,10 @@ def confluence_doc_sync(
confluence_client=confluence_client, confluence_client=confluence_client,
is_cloud=is_cloud, is_cloud=is_cloud,
) )
slim_docs = [
slim_doc slim_docs = []
for doc_batch in confluence_connector.retrieve_all_slim_documents() for doc_batch in confluence_connector.retrieve_all_slim_documents():
for slim_doc in doc_batch slim_docs.extend(doc_batch)
]
permissions_by_doc_id = _fetch_all_page_restrictions_for_space( permissions_by_doc_id = _fetch_all_page_restrictions_for_space(
confluence_client=confluence_client, confluence_client=confluence_client,

View File

@@ -1,91 +1,40 @@
from collections.abc import Iterator from typing import Any
from atlassian import Confluence # type:ignore
from requests import HTTPError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from danswer.connectors.confluence.onyx_confluence import ( from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
handle_confluence_rate_limit, from danswer.connectors.confluence.utils import build_confluence_client
) from danswer.connectors.confluence.utils import get_user_email_from_username__server
from danswer.db.models import ConnectorCredentialPair from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import ExternalUserGroup from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.confluence.sync_utils import (
build_confluence_client,
)
from ee.danswer.external_permissions.confluence.sync_utils import (
get_user_email_from_username__server,
)
logger = setup_logger() logger = setup_logger()
_PAGE_SIZE = 100
def _get_confluence_group_names_paginated(
confluence_client: Confluence,
) -> Iterator[str]:
get_all_groups = handle_confluence_rate_limit(confluence_client.get_all_groups)
start = 0
while True:
try:
groups = get_all_groups(start=start, limit=_PAGE_SIZE)
except HTTPError as e:
if e.response.status_code in (403, 404):
return
raise e
for group in groups:
if group_name := group.get("name"):
yield group_name
if len(groups) < _PAGE_SIZE:
break
start += _PAGE_SIZE
def _get_group_members_email_paginated( def _get_group_members_email_paginated(
confluence_client: Confluence, confluence_client: OnyxConfluence,
group_name: str, group_name: str,
is_cloud: bool,
) -> set[str]: ) -> set[str]:
get_group_members = handle_confluence_rate_limit( members: list[dict[str, Any]] = []
confluence_client.get_group_members for member_batch in confluence_client.paginated_group_members_retrieval(group_name):
) members.extend(member_batch)
group_member_emails: set[str] = set() group_member_emails: set[str] = set()
start = 0 for member in members:
while True: email = member.get("email")
try: if not email:
members = get_group_members( user_name = member.get("username")
group_name=group_name, start=start, limit=_PAGE_SIZE if user_name:
)
except HTTPError as e:
if e.response.status_code == 403 or e.response.status_code == 404:
return group_member_emails
raise e
for member in members:
if is_cloud:
email = member.get("email")
elif user_name := member.get("username"):
email = get_user_email_from_username__server( email = get_user_email_from_username__server(
confluence_client, user_name confluence_client=confluence_client,
user_name=user_name,
) )
else: if email:
logger.warning(f"Member has no email or username: {member}") group_member_emails.add(email)
email = None
if email:
group_member_emails.add(email)
if len(members) < _PAGE_SIZE:
break
start += _PAGE_SIZE
return group_member_emails return group_member_emails
@@ -94,17 +43,25 @@ def confluence_group_sync(
db_session: Session, db_session: Session,
cc_pair: ConnectorCredentialPair, cc_pair: ConnectorCredentialPair,
) -> None: ) -> None:
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
confluence_client = build_confluence_client( confluence_client = build_confluence_client(
connector_specific_config=cc_pair.connector.connector_specific_config,
credentials_json=cc_pair.credential.credential_json, credentials_json=cc_pair.credential.credential_json,
is_cloud=is_cloud,
wiki_base=cc_pair.connector.connector_specific_config["wiki_base"],
) )
# Get all group names
group_names: list[str] = []
for group_batch in confluence_client.paginated_groups_retrieval():
for group in group_batch:
if group_name := group.get("name"):
group_names.append(group_name)
# For each group name, get all members and create a danswer group
danswer_groups: list[ExternalUserGroup] = [] danswer_groups: list[ExternalUserGroup] = []
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False) for group_name in group_names:
# Confluence enforces that group names are unique
for group_name in _get_confluence_group_names_paginated(confluence_client):
group_member_emails = _get_group_members_email_paginated( group_member_emails = _get_group_members_email_paginated(
confluence_client, group_name, is_cloud confluence_client, group_name
) )
group_members = batch_add_non_web_user_if_not_exists__no_commit( group_members = batch_add_non_web_user_if_not_exists__no_commit(
db_session=db_session, emails=list(group_member_emails) db_session=db_session, emails=list(group_member_emails)

View File

@@ -1,43 +0,0 @@
from typing import Any
from danswer.connectors.confluence.connector import OnyxConfluence
from danswer.connectors.confluence.onyx_confluence import (
handle_confluence_rate_limit,
)
_USER_EMAIL_CACHE: dict[str, str | None] = {}
def build_confluence_client(
connector_specific_config: dict[str, Any], credentials_json: dict[str, Any]
) -> OnyxConfluence:
is_cloud = connector_specific_config.get("is_cloud", False)
return OnyxConfluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present
url=connector_specific_config["wiki_base"].rstrip("/"),
# passing in username causes issues for Confluence data center
username=credentials_json["confluence_username"] if is_cloud else None,
password=credentials_json["confluence_access_token"] if is_cloud else None,
token=credentials_json["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=60,
max_backoff_seconds=60,
)
def get_user_email_from_username__server(
confluence_client: OnyxConfluence, user_name: str
) -> str | None:
global _USER_EMAIL_CACHE
get_user_info = handle_confluence_rate_limit(
confluence_client.get_mobile_parameters
)
if _USER_EMAIL_CACHE.get(user_name) is None:
try:
response = get_user_info(user_name)
email = response.get("email")
except Exception:
email = None
_USER_EMAIL_CACHE[user_name] = email
return _USER_EMAIL_CACHE[user_name]