Files
danswer/backend/onyx/connectors/confluence/onyx_confluence.py
Chris Weaver f767b1f476 Fix confluence permission syncing at scale (#4129)
* Fix confluence permission syncing at scale

* Remove line

* Better log message

* Adjust log
2025-02-25 19:22:52 -08:00

570 lines
22 KiB
Python

import functools
import math
import time
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import cast
from typing import TypeVar
from urllib.parse import quote

from atlassian import Confluence  # type:ignore
from pydantic import BaseModel
from requests import HTTPError

from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.utils.logger import setup_logger
logger = setup_logger()

# Type variable for the decorated callable, so handle_confluence_rate_limit
# hands back the same callable type it was given (for type checkers).
F = TypeVar("F", bound=Callable[..., Any])

# Marker searched for in lower-cased HTTP error bodies to treat a non-429
# response as a rate limit.
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()

# https://jira.atlassian.com/browse/CONFCLOUD-76433
# When a paginated request whose URL contains this expansion fails,
# _paginate_url swaps it for _REPLACEMENT_EXPANSIONS and retries the request.
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
_REPLACEMENT_EXPANSIONS = "body.view.value"
class ConfluenceRateLimitError(Exception):
    """Exception type intended (judging by its name) for Confluence rate-limit failures.

    No code in this module visibly raises it; presumably other modules raise or
    catch it — verify before relying on it.
    """
class ConfluenceUser(BaseModel):
    """A Confluence user, as produced by OnyxConfluence.paginated_cql_user_retrieval.

    One model covers both Confluence Cloud and Server/Data Center; which fields
    are populated depends on the deployment (see the per-field comments).
    """

    user_id: str  # accountId in Cloud, userKey in Server
    username: str | None  # Confluence Cloud doesn't give usernames
    display_name: str
    # Confluence Data Center doesn't give email back by default,
    # have to fetch it with a different endpoint
    email: str | None
    # Cloud: the user's "accountType"; Server/Data Center: "type" (defaults to "user")
    type: str
def _handle_http_error(e: HTTPError, attempt: int) -> int:
    """Compute when a rate-limited Confluence call may be retried.

    Errors that are not rate limits are re-raised unchanged. For a rate limit,
    an integer ``Retry-After`` header is honoured after clamping it to
    [2, 60] seconds. Without a usable header, an exponential backoff of
    ``5 * 2**attempt`` seconds, capped at 60, is used instead.

    Args:
        e: The HTTPError raised by the Confluence client.
        attempt: Zero-based attempt number, used for the exponential backoff.

    Returns:
        A ``time.monotonic()`` deadline, rounded up to a whole second. The
        caller should wait until then before retrying.

    Raises:
        HTTPError: ``e`` itself when it has no response/headers or is not a
            rate-limit error.
    """
    floor_seconds = 2
    ceiling_seconds = 60
    base_seconds = 5
    growth = 2

    response = e.response
    # Without a response (or headers) nothing below can be inspected safely.
    if response is None or response.headers is None:
        logger.warning("HTTPError with `None` as response or as headers")
        raise e

    # 429, or a body that mentions the rate-limit message, counts as rate limiting.
    is_rate_limited = (
        response.status_code == 429
        or RATE_LIMIT_MESSAGE_LOWERCASE in response.text.lower()
    )
    if not is_rate_limited:
        raise e

    wait_seconds = None
    header_value = response.headers.get("Retry-After")
    if header_value is not None:
        try:
            parsed_seconds = int(header_value)
        except ValueError:
            # Non-integer Retry-After (e.g. an HTTP date): fall back to backoff.
            parsed_seconds = None
        if parsed_seconds is not None:
            if parsed_seconds > ceiling_seconds:
                logger.warning(
                    f"Clamping retry_after from {parsed_seconds} to {ceiling_seconds} seconds..."
                )
            wait_seconds = max(floor_seconds, min(parsed_seconds, ceiling_seconds))

    if wait_seconds is None:
        logger.warning(
            "Rate limiting without retry header. Retrying with exponential backoff..."
        )
        wait_seconds = min(base_seconds * (growth**attempt), ceiling_seconds)
    else:
        logger.warning(
            f"Rate limiting with retry header. Retrying after {wait_seconds} seconds..."
        )

    return math.ceil(time.monotonic() + wait_seconds)
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def handle_confluence_rate_limit(confluence_call: F) -> F:
    """Return ``confluence_call`` wrapped with retries for rate limits.

    Up to 5 attempts are made within an overall budget of 600 seconds.

    * ``HTTPError``: ``_handle_http_error`` re-raises anything that is not a
      rate limit. A rate limit is retried after the delay it computes. If the
      final attempt is still rate limited, that ``HTTPError`` is re-raised.
    * ``AttributeError`` (reported as intermittent from inside the Confluence
      library): retried after 5 seconds, and re-raised on the final attempt.

    NOTE: retries only cover the call itself. If ``confluence_call`` is a
    generator function, calling it merely creates the generator, so errors
    raised while iterating it are not retried here.

    Raises:
        TimeoutError: if the time budget is exhausted before an attempt.
    """

    @functools.wraps(confluence_call)
    def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
        MAX_RETRIES = 5
        TIMEOUT = 600
        timeout_at = time.monotonic() + TIMEOUT

        for attempt in range(MAX_RETRIES):
            if time.monotonic() > timeout_at:
                raise TimeoutError(
                    f"Confluence call attempts took longer than {TIMEOUT} seconds."
                )

            try:
                # we're relying more on the client to rate limit itself
                # and applying our own retries in a more specific set of circumstances
                return confluence_call(*args, **kwargs)
            except HTTPError as e:
                if attempt == MAX_RETRIES - 1:
                    # Out of attempts: propagate instead of falling out of the
                    # loop and implicitly returning None to the caller.
                    raise
                # Re-raises e unless it is a rate-limit error.
                delay_until = _handle_http_error(e, attempt)
                # delay_until is a time.monotonic() deadline, not a duration.
                wait_seconds = max(0.0, delay_until - time.monotonic())
                logger.warning(
                    f"HTTPError in confluence call. "
                    f"Retrying in {wait_seconds:.0f} seconds..."
                )
                while time.monotonic() < delay_until:
                    # in the future, check a signal here to exit
                    time.sleep(1)
            except AttributeError as e:
                # Some error within the Confluence library, unclear why it fails.
                # Users reported it to be intermittent, so just retry
                if attempt == MAX_RETRIES - 1:
                    raise e
                logger.exception(
                    "Confluence Client raised an AttributeError. Retrying..."
                )
                time.sleep(5)

    return cast(F, wrapped_call)
# Page size _paginate_url requests when the caller passes no (or a falsy) limit.
_DEFAULT_PAGINATION_LIMIT = 1000
# On an HTTP 500, _paginate_url halves the page size and retries only while the
# current limit is above this value.
_MINIMUM_PAGINATION_LIMIT = 50
class OnyxConfluence(Confluence):
    """Confluence client with Onyx-specific pagination, CQL and permission helpers.

    This subclasses ``atlassian.Confluence`` because the stock client does not
    properly support CQL expansions. On construction, every public callable
    attribute is wrapped with ``handle_confluence_rate_limit`` (see
    ``_wrap_methods``).

    NOTE: ``paginated_*`` and ``cql_paginate_all_expansions`` are generator
    functions, so the wrapper only guards the call that creates the generator.
    The HTTP requests made while iterating go through ``self.get``, which is
    wrapped on its own.
    """

    def __init__(self, url: str, *args: Any, **kwargs: Any) -> None:
        super().__init__(url, *args, **kwargs)
        # Wrap only after the base client is initialised, so inherited public
        # methods are covered as well.
        self._wrap_methods()

    def _wrap_methods(self) -> None:
        """Wrap each public callable attribute with handle_confluence_rate_limit.

        Names starting with an underscore (private helpers such as
        ``_paginate_url``, and all dunders) are skipped. The wrapped callables are
        stored as instance attributes, shadowing the class methods for this
        instance only.
        """
        for attr_name in dir(self):
            # Check the name first so private attributes are never looked up.
            if attr_name.startswith("_"):
                continue
            attr = getattr(self, attr_name)
            if callable(attr):
                setattr(self, attr_name, handle_confluence_rate_limit(attr))

    def _paginate_url(
        self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False
    ) -> Iterator[dict[str, Any]]:
        """Yield every item of ``results`` across all pages of ``url_suffix``.

        Args:
            url_suffix: API path, optionally with a query string, relative to
                the client's base URL. ``limit=<limit>`` is appended to it.
            limit: Requested page size. Falsy values fall back to
                ``_DEFAULT_PAGINATION_LIMIT``.
            auto_paginate: For endpoints whose ``_links.next`` is unreliable.
                After a non-empty page, the next URL is built by advancing
                ``start`` on the previous URL by the number of results received.

        Each request is handled as follows:
            * If the request fails and the URL contains
              ``_PROBLEMATIC_EXPANSIONS``, that expansion is swapped for
              ``_REPLACEMENT_EXPANSIONS`` and the request is retried.
            * On HTTP 500, the page size is halved (while it is above
              ``_MINIMUM_PAGINATION_LIMIT``) and the request is retried.
            * Any other failure is logged and re-raised.
        """
        if not limit:
            limit = _DEFAULT_PAGINATION_LIMIT

        connection_char = "&" if "?" in url_suffix else "?"
        url_suffix += f"{connection_char}limit={limit}"

        while url_suffix:
            logger.debug(f"Making confluence call to {url_suffix}")
            try:
                raw_response = self.get(
                    path=url_suffix,
                    advanced_mode=True,
                )
            except Exception as e:
                logger.exception(f"Error in confluence call to {url_suffix}")
                raise e

            try:
                raw_response.raise_for_status()
            except Exception as e:
                logger.warning(f"Error in confluence call to {url_suffix}")

                # If the problematic expansion is in the url, replace it with
                # the replacement expansion and try again. If the retried
                # request also fails, the checks below apply to it.
                if _PROBLEMATIC_EXPANSIONS in url_suffix:
                    logger.warning(
                        f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS}"
                        " and trying again."
                    )
                    url_suffix = url_suffix.replace(
                        _PROBLEMATIC_EXPANSIONS,
                        _REPLACEMENT_EXPANSIONS,
                    )
                    continue

                # On a server error, retry the same page with a smaller page size.
                # NOTE(review): if the URL carries a limit value other than
                # `limit` (e.g. a server-provided next link), this replace is a
                # no-op and the identical request is retried.
                if (
                    raw_response.status_code == 500
                    and limit > _MINIMUM_PAGINATION_LIMIT
                ):
                    new_limit = limit // 2
                    logger.warning(
                        f"Error in confluence call to {url_suffix} \n"
                        f"Raw Response Text: {raw_response.text} \n"
                        f"Full Response: {raw_response.__dict__} \n"
                        f"Error: {e} \n"
                        f"Reducing limit from {limit} to {new_limit} and trying again."
                    )
                    url_suffix = url_suffix.replace(
                        f"limit={limit}", f"limit={new_limit}"
                    )
                    limit = new_limit
                    continue

                logger.exception(
                    f"Error in confluence call to {url_suffix} \n"
                    f"Raw Response Text: {raw_response.text} \n"
                    f"Full Response: {raw_response.__dict__} \n"
                    f"Error: {e} \n"
                )
                raise e

            try:
                next_response = raw_response.json()
            except Exception as e:
                logger.exception(
                    f"Failed to parse response as JSON. Response: {raw_response.__dict__}"
                )
                raise e

            # yield the results individually
            results = cast(list[dict[str, Any]], next_response.get("results", []))
            yield from results

            old_url_suffix = url_suffix
            url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))

            # make sure we don't update the start by more than the amount
            # of results we were able to retrieve. The Confluence API has a
            # weird behavior where if you pass in a limit that is too large for
            # the configured server, it will artificially limit the amount of
            # results returned BUT will not apply this to the start parameter.
            # This will cause us to miss results.
            if url_suffix and "start" in url_suffix:
                new_start = get_start_param_from_url(url_suffix)
                previous_start = get_start_param_from_url(old_url_suffix)
                if new_start - previous_start > len(results):
                    logger.warning(
                        f"Start was updated by more than the amount of results "
                        f"retrieved. This is a bug with Confluence. Start: {new_start}, "
                        f"Previous Start: {previous_start}, Len Results: {len(results)}."
                    )

                    # Update the url_suffix to use the adjusted start
                    adjusted_start = previous_start + len(results)
                    url_suffix = update_param_in_path(
                        url_suffix, "start", str(adjusted_start)
                    )

            # some APIs don't properly paginate, so we need to manually update the `start` param
            if auto_paginate and len(results) > 0:
                previous_start = get_start_param_from_url(old_url_suffix)
                updated_start = previous_start + len(results)
                url_suffix = update_param_in_path(
                    old_url_suffix, "start", str(updated_start)
                )

    def paginated_cql_retrieval(
        self,
        cql: str,
        expand: str | None = None,
        limit: int | None = None,
    ) -> Iterator[dict[str, Any]]:
        """Yield objects matching ``cql`` from the content/search endpoint.

        The content/search endpoint can return pages, attachments and comments.
        ``cql`` and ``expand`` are inserted into the URL verbatim, so they are
        presumably expected to be URL-safe already — verify against callers.
        """
        expand_string = f"&expand={expand}" if expand else ""
        yield from self._paginate_url(
            f"rest/api/content/search?cql={cql}{expand_string}", limit
        )

    def cql_paginate_all_expansions(
        self,
        cql: str,
        expand: str | None = None,
        limit: int | None = None,
    ) -> Iterator[dict[str, Any]]:
        """Yield CQL results with every nested expansion fully paginated.

        The top-level query is paginated first, using ``limit``. Then, for each
        yielded object, any nested dict that has both a ``results`` list and a
        ``_links.next`` URL gets its remaining pages fetched via
        ``_paginate_url`` and appended to ``results``. This is applied
        recursively, including to the newly appended items.

        NOTE: ``limit`` only applies to the top-level query. Expansion pages
        are fetched via ``_paginate_url`` with no explicit limit, which appends
        ``limit=_DEFAULT_PAGINATION_LIMIT`` to the ``next`` URL. If that URL
        already carries its own ``limit``, the parameter appears twice —
        TODO confirm which value Confluence honours.
        """

        def _traverse_and_update(data: dict | list) -> None:
            if isinstance(data, dict):
                next_url = data.get("_links", {}).get("next")
                if next_url and "results" in data:
                    data["results"].extend(self._paginate_url(next_url))

                for value in data.values():
                    _traverse_and_update(value)
            elif isinstance(data, list):
                for item in data:
                    _traverse_and_update(item)

        for confluence_object in self.paginated_cql_retrieval(cql, expand, limit):
            _traverse_and_update(confluence_object)
            yield confluence_object

    def paginated_cql_user_retrieval(
        self,
        expand: str | None = None,
        limit: int | None = None,
    ) -> Iterator[ConfluenceUser]:
        """Yield every user returned by Confluence as a ``ConfluenceUser``.

        Cloud: uses the search/user endpoint with ``cql=type=user``. It is a
        separate endpoint from content/search, used only for users, but
        otherwise very similar. It doesn't paginate properly, so ``start`` is
        advanced manually via ``auto_paginate``.

        Server: uses rest/api/user/list, which is only available on Data Center
        deployments. It does not return emails, so ``email`` is None.
        """
        if self.cloud:
            cql = "type=user"
            url = "rest/api/search/user"
            expand_string = f"&expand={expand}" if expand else ""
            url += f"?cql={cql}{expand_string}"
            # endpoint doesn't properly paginate, so we need to manually update the `start` param
            # thus the auto_paginate flag
            for user_result in self._paginate_url(url, limit, auto_paginate=True):
                # Example response:
                # {
                #     'user': {
                #         'type': 'known',
                #         'accountId': '712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d',
                #         'accountType': 'atlassian',
                #         'email': 'chris@danswer.ai',
                #         'publicName': 'Chris Weaver',
                #         'profilePicture': {
                #             'path': '/wiki/aa-avatar/712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d',
                #             'width': 48,
                #             'height': 48,
                #             'isDefault': False
                #         },
                #         'displayName': 'Chris Weaver',
                #         'isExternalCollaborator': False,
                #         '_expandable': {
                #             'operations': '',
                #             'personalSpace': ''
                #         },
                #         '_links': {
                #             'self': 'https://danswerai.atlassian.net/wiki/rest/api/user?accountId=712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d'
                #         }
                #     },
                #     'title': 'Chris Weaver',
                #     'excerpt': '',
                #     'url': '/people/712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d',
                #     'breadcrumbs': [],
                #     'entityType': 'user',
                #     'iconCssClass': 'aui-icon content-type-profile',
                #     'lastModified': '2025-02-18T04:08:03.579Z',
                #     'score': 0.0
                # }
                user = user_result["user"]
                yield ConfluenceUser(
                    user_id=user["accountId"],
                    username=None,
                    display_name=user["displayName"],
                    email=user.get("email"),
                    type=user["accountType"],
                )
        else:
            # https://developer.atlassian.com/server/confluence/rest/v900/api-group-user/#api-rest-api-user-list-get
            # ^ is only available on data center deployments
            # Example response:
            # [
            #     {
            #         'type': 'known',
            #         'username': 'admin',
            #         'userKey': '40281082950c5fe901950c61c55d0000',
            #         'profilePicture': {
            #             'path': '/images/icons/profilepics/default.svg',
            #             'width': 48,
            #             'height': 48,
            #             'isDefault': True
            #         },
            #         'displayName': 'Admin Test',
            #         '_links': {
            #             'self': 'http://localhost:8090/rest/api/user?key=40281082950c5fe901950c61c55d0000'
            #         },
            #         '_expandable': {
            #             'status': ''
            #         }
            #     }
            # ]
            for user in self._paginate_url("rest/api/user/list", limit):
                yield ConfluenceUser(
                    user_id=user["userKey"],
                    username=user["username"],
                    display_name=user["displayName"],
                    email=None,
                    type=user.get("type", "user"),
                )

    def paginated_groups_by_user_retrieval(
        self,
        user_id: str,  # accountId in Cloud, userKey in Server
        limit: int | None = None,
    ) -> Iterator[dict[str, Any]]:
        """Yield the groups ``user_id`` belongs to (rest/api/user/memberof).

        This is a Confluence-specific endpoint, not a CQL query.
        """
        user_field = "accountId" if self.cloud else "key"
        user_value = user_id
        # Server uses userKey (but calls it key during the API call), Cloud uses accountId
        user_query = f"{user_field}={quote(user_value)}"

        url = f"rest/api/user/memberof?{user_query}"
        yield from self._paginate_url(url, limit)

    def paginated_groups_retrieval(
        self,
        limit: int | None = None,
    ) -> Iterator[dict[str, Any]]:
        """Yield all groups from the rest/api/group endpoint (not a CQL query)."""
        yield from self._paginate_url("rest/api/group", limit)

    def paginated_group_members_retrieval(
        self,
        group_name: str,
        limit: int | None = None,
    ) -> Iterator[dict[str, Any]]:
        """Yield the members of ``group_name`` (rest/api/group/{name}/member).

        This is a Confluence-specific endpoint, not a CQL query.
        THIS DOESN'T WORK FOR SERVER because it breaks when there is a slash in
        the group name. E.g. neither "test/group" nor "test%2Fgroup" works for
        confluence.
        """
        group_name = quote(group_name)
        yield from self._paginate_url(f"rest/api/group/{group_name}/member", limit)

    def get_all_space_permissions_server(
        self,
        space_key: str,
    ) -> list[dict[str, Any]]:
        """Return the permission sets of ``space_key`` on Confluence Server.

        Calls the ``getSpacePermissionSets`` JSON-RPC method directly. Per the
        original author's note, this is preferred over the client's
        ``get_space_permissions`` for better logging of the JSON-RPC response.
        Returns the response's ``result`` list, or ``[]`` (with a warning) when
        the result is missing or empty.

        TODO: Make this call these endpoints for newer confluence versions:
            - /rest/api/space/{spaceKey}/permissions
            - /rest/api/space/{spaceKey}/permissions/anonymous
        """
        url = "rpc/json-rpc/confluenceservice-v2"
        data = {
            "jsonrpc": "2.0",
            "method": "getSpacePermissionSets",
            "id": 7,
            "params": [space_key],
        }
        response = self.post(url, data=data)
        logger.debug(f"jsonrpc response: {response}")
        result = response.get("result")
        if not result:
            logger.warning(
                f"No jsonrpc response for space permissions for space {space_key}"
                f"\nResponse: {response}"
            )
            # A present-but-null "result" must not leak out as None.
            return []

        return result

    def get_current_user(self, expand: str | None = None) -> Any:
        """
        Implements a method that isn't in the third party client.

        Get information about the current user
        :param expand: OPTIONAL expand for get status of user.
                Possible param is "status". Results are "Active, Deactivated"
        :return: Returns the user details
        :raises ApiPermissionError: if Confluence answers with HTTP 403.
        """

        from atlassian.errors import ApiPermissionError  # type:ignore

        url = "rest/api/user/current"
        params: dict[str, str] = {}
        if expand:
            params["expand"] = expand
        try:
            response = self.get(url, params=params)
        except HTTPError as e:
            # e.response may be None. Guard it so the HTTPError is re-raised
            # instead of an AttributeError, which the method wrapper would retry.
            if e.response is not None and e.response.status_code == 403:
                raise ApiPermissionError(
                    "The calling user does not have permission", reason=e
                ) from e
            raise
        return response
def _validate_connector_configuration(
    credentials: dict[str, Any],
    is_cloud: bool,
    wiki_base: str,
) -> None:
    """Verify that the credentials can list at least one space at ``wiki_base``.

    A plain ``atlassian.Confluence`` client is used (not ``OnyxConfluence``), so
    ``handle_confluence_rate_limit`` does not apply. Only the client's own
    backoff settings, configured below, are in effect.

    Args:
        credentials: ``confluence_access_token`` is always required. Cloud
            also requires ``confluence_username``.
        is_cloud: Whether ``wiki_base`` points at Confluence Cloud.
        wiki_base: Base URL of the wiki; a trailing slash is stripped.

    Raises:
        RuntimeError: if the space listing comes back empty.
        KeyError: if a required credential is missing.
    """
    # Test the connection with a plain client and a short backoff budget.
    confluence_client_with_minimal_retries = Confluence(
        api_version="cloud" if is_cloud else "latest",
        url=wiki_base.rstrip("/"),
        username=credentials["confluence_username"] if is_cloud else None,
        password=credentials["confluence_access_token"] if is_cloud else None,
        token=credentials["confluence_access_token"] if not is_cloud else None,
        backoff_and_retry=True,
        max_backoff_retries=6,
        max_backoff_seconds=10,
    )
    spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)

    # NOTE: per earlier testing, Confluence returns all data in UTC regardless of
    # the user's time zone, even though CQL parses incoming times in the user's
    # time zone.

    # get_all_spaces is expected to return the paginated JSON body: a dict with a
    # "results" list (verify against the installed atlassian-python-api). That
    # dict is truthy even when "results" is empty, so inspect the list itself.
    # Any other shape falls back to the raw value's truthiness.
    space_results = spaces.get("results", spaces) if isinstance(spaces, dict) else spaces
    if not space_results:
        raise RuntimeError(
            f"No spaces found at {wiki_base}! "
            "Check your credentials and wiki_base and make sure "
            "is_cloud is set correctly."
        )
def build_confluence_client(
    credentials: dict[str, Any],
    is_cloud: bool,
    wiki_base: str,
) -> OnyxConfluence:
    """Validate the connector settings and return a configured OnyxConfluence.

    Args:
        credentials: ``confluence_access_token`` is always required. Cloud
            also requires ``confluence_username``.
        is_cloud: True for Confluence Cloud, where the username and token are
            passed as username/password. False for Server/Data Center, where
            only the token is passed.
        wiki_base: Base URL of the wiki; a trailing slash is stripped.

    Raises:
        ConnectorValidationError: if ``_validate_connector_configuration``
            raises anything. The original exception is chained as the cause.
    """
    try:
        _validate_connector_configuration(
            credentials=credentials,
            is_cloud=is_cloud,
            wiki_base=wiki_base,
        )
    except Exception as e:
        raise ConnectorValidationError(str(e)) from e

    return OnyxConfluence(
        api_version="cloud" if is_cloud else "latest",
        # Remove trailing slash from wiki_base if present
        url=wiki_base.rstrip("/"),
        # passing in username causes issues for Confluence data center
        username=credentials["confluence_username"] if is_cloud else None,
        password=credentials["confluence_access_token"] if is_cloud else None,
        token=credentials["confluence_access_token"] if not is_cloud else None,
        backoff_and_retry=True,
        max_backoff_retries=10,
        max_backoff_seconds=60,
        cloud=is_cloud,
    )