Confluence polish

This commit is contained in:
hagen-danswer 2024-10-22 13:21:07 -07:00
parent 6e9b6a1075
commit 1a019dd6d5
6 changed files with 155 additions and 295 deletions

View File

@ -1,6 +1,7 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from urllib.parse import quote
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
@ -8,6 +9,7 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
from danswer.connectors.confluence.utils import attachment_to_content
from danswer.connectors.confluence.utils import build_confluence_client
from danswer.connectors.confluence.utils import build_confluence_document_id
from danswer.connectors.confluence.utils import datetime_from_string
from danswer.connectors.confluence.utils import extract_text_from_confluence_html
@ -74,20 +76,21 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.wiki_base = wiki_base.rstrip("/")
# if nothing is provided, we will fetch all pages
self.cql_page_query = "type=page"
cql_page_query = "type=page"
if cql_query:
# if a cql_query is provided, we will use it to fetch the pages
self.cql_page_query = cql_query
cql_page_query = cql_query
elif space:
# if no cql_query is provided, we will use the space to fetch the pages
self.cql_page_query += f" and space='{space}'"
cql_page_query += f" and space='{quote(space)}'"
elif page_id:
if index_recursively:
self.cql_page_query += f" and ancestor='{page_id}'"
cql_page_query += f" and ancestor='{page_id}'"
else:
# if neither a space nor a cql_query is provided, we will use the page_id to fetch the page
self.cql_page_query += f" and id='{page_id}'"
cql_page_query += f" and id='{page_id}'"
self.cql_page_query = cql_page_query
self.cql_label_filter = ""
self.cql_time_filter = ""
if labels_to_skip:
@ -96,19 +99,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.cql_label_filter = f"&label not in ({comma_separated_labels})"
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
    """
    Build the Confluence client for this connector and store it on
    `self.confluence_client`.

    Args:
        credentials: Credential mapping, forwarded to
            `build_confluence_client` as `credentials_json`.

    Returns:
        Always None.
    """
    # Client construction is delegated to `build_confluence_client`.
    # see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
    # for a list of other hidden constructor args
    self.confluence_client = build_confluence_client(
        credentials_json=credentials,
        is_cloud=self.is_cloud,
        wiki_base=self.wiki_base,
    )
    return None
@ -202,12 +198,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
# Fetch pages as Documents
for pages in self.confluence_client.paginated_cql_page_retrieval(
for page_batch in self.confluence_client.paginated_cql_page_retrieval(
cql=page_query,
expand=",".join(_PAGE_EXPANSION_FIELDS),
limit=self.batch_size,
):
for page in pages:
for page in page_batch:
confluence_page_ids.append(page["id"])
doc = self._convert_object_to_document(page)
if doc is not None:

View File

@ -5,6 +5,7 @@ from collections.abc import Iterator
from typing import Any
from typing import cast
from typing import TypeVar
from urllib.parse import quote
from atlassian import Confluence # type:ignore
from requests import HTTPError
@ -111,7 +112,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
return cast(F, wrapped_call)
_PAGINATION_LIMIT = 100
_DEFAULT_PAGINATION_LIMIT = 100
class OnyxConfluence(Confluence):
@ -138,35 +139,62 @@ class OnyxConfluence(Confluence):
handle_confluence_rate_limit(getattr(self, attr_name)),
)
def _paginate_url(
    self, url_suffix: str, limit: int | None = None
) -> Iterator[list[dict[str, Any]]]:
    """
    This will paginate through the top level query.

    Requests `url_suffix`, yields the response's `results` list, then keeps
    following the response's `_links.next` value until none is returned.

    Args:
        url_suffix: URL to request first; may already contain a query string.
        limit: Page size appended to the first request. Falls back to
            `_DEFAULT_PAGINATION_LIMIT` when falsy. Not appended when
            `url_suffix` already has a `limit` query parameter.

    Yields:
        The `results` list of each response (an empty list when the key is
        absent).

    Raises:
        Exception: whatever `self.get` raises, re-raised after being logged.
    """
    # Kept local so this block does not rely on additional file-level imports.
    from urllib.parse import parse_qs, urlsplit

    # A URL that already carries its own `limit` (e.g. a `_links.next` value
    # passed back in by a caller) must not get a second, duplicated parameter.
    if "limit" not in parse_qs(urlsplit(url_suffix).query):
        page_size = limit if limit else _DEFAULT_PAGINATION_LIMIT
        connection_char = "&" if "?" in url_suffix else "?"
        url_suffix += f"{connection_char}limit={page_size}"

    while url_suffix:
        try:
            next_response = self.get(url_suffix)
        except Exception:
            logger.exception("Error in danswer_cql: \n")
            # Bare raise re-raises the active exception with its traceback.
            raise
        yield next_response.get("results", [])
        url_suffix = next_response.get("_links", {}).get("next")
def paginated_groups_retrieval(
    self,
    limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
    """
    Return an iterator over pages of results from `rest/api/group`.

    Pagination is delegated to `_paginate_url`; `limit` is passed through
    unchanged.
    """
    groups_url = "rest/api/group"
    return self._paginate_url(groups_url, limit)
def paginated_group_members_retrieval(
    self,
    group_name: str,
    limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
    """
    Return an iterator over pages of members of the group `group_name`.

    The group name is percent-encoded with `quote` before being placed in
    the URL path; pagination is delegated to `_paginate_url`.
    """
    members_url = f"rest/api/group/{quote(group_name)}/member"
    return self._paginate_url(members_url, limit)
def paginated_cql_user_retrieval(
    self,
    cql: str,
    expand: str | None = None,
    limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
    """
    Return an iterator over pages of `rest/api/search/user` results for `cql`.

    Args:
        cql: Query string placed verbatim in the `cql` query parameter.
        expand: Optional value for the `expand` query parameter; omitted
            when falsy.
        limit: Page size, passed through to `_paginate_url`.
    """
    search_url = f"rest/api/search/user?cql={cql}"
    if expand:
        search_url += f"&expand={expand}"
    return self._paginate_url(search_url, limit)
def paginated_cql_page_retrieval(
    self,
    cql: str,
    expand: str | None = None,
    limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
    """
    Return an iterator over pages of `rest/api/content/search` results for
    `cql`.

    Pagination is delegated to `_paginate_url`, the same helper the other
    `paginated_*` retrieval methods use, instead of a second hand-written
    pagination loop.

    Args:
        cql: Query string placed verbatim in the `cql` query parameter.
        expand: Optional value for the `expand` query parameter; omitted
            when falsy.
        limit: Page size, passed through to `_paginate_url`.
    """
    expand_string = f"&expand={expand}" if expand else ""
    return self._paginate_url(
        f"rest/api/content/search?cql={cql}{expand_string}", limit
    )
def cql_paginate_all_expansions(
self,
@ -185,10 +213,7 @@ class OnyxConfluence(Confluence):
if isinstance(data, dict):
next_url = data.get("_links", {}).get("next")
if next_url and "results" in data:
while next_url:
next_response = self.get(next_url)
data["results"].extend(next_response.get("results", []))
next_url = next_response.get("_links", {}).get("next")
data["results"].extend(self._paginate_url(next_url))
for value in data.values():
_traverse_and_update(value)
@ -199,113 +224,3 @@ class OnyxConfluence(Confluence):
for results in self.paginated_cql_page_retrieval(cql, expand, limit):
_traverse_and_update(results)
yield results
# commenting out while we try using confluence's rate limiter instead
# # https://developer.atlassian.com/cloud/confluence/rate-limiting/
# def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
# def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
# max_retries = 5
# starting_delay = 5
# backoff = 2
# # max_delay is used when the server doesn't hand back "Retry-After"
# # and we have to decide the retry delay ourselves
# max_delay = 30 # Atlassian uses max_delay = 30 in their examples
# # max_retry_after is used when we do get a "Retry-After" header
# max_retry_after = 300 # should we really cap the maximum retry delay?
# NEXT_RETRY_KEY = BaseConnector.REDIS_KEY_PREFIX + "confluence_next_retry"
# # for testing purposes, rate limiting is written to fall back to a simpler
# # rate limiting approach when redis is not available
# r = get_redis_client()
# for attempt in range(max_retries):
# try:
# # if multiple connectors are waiting for the next attempt, there could be an issue
# # where many connectors are "released" onto the server at the same time.
# # That's not ideal ... but coming up with a mechanism for queueing
# # all of these connectors is a bigger problem than we want to take on
# # right now
# try:
# next_attempt = r.get(NEXT_RETRY_KEY)
# if next_attempt is None:
# next_attempt = 0
# else:
# next_attempt = int(cast(int, next_attempt))
# # TODO: all connectors need to be interruptible moving forward
# while time.monotonic() < next_attempt:
# time.sleep(1)
# except ConnectionError:
# pass
# return confluence_call(*args, **kwargs)
# except HTTPError as e:
# # Check if the response or headers are None to avoid potential AttributeError
# if e.response is None or e.response.headers is None:
# logger.warning("HTTPError with `None` as response or as headers")
# raise e
# retry_after_header = e.response.headers.get("Retry-After")
# if (
# e.response.status_code == 429
# or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
# ):
# retry_after = None
# if retry_after_header is not None:
# try:
# retry_after = int(retry_after_header)
# except ValueError:
# pass
# if retry_after is not None:
# if retry_after > max_retry_after:
# logger.warning(
# f"Clamping retry_after from {retry_after} to {max_delay} seconds..."
# )
# retry_after = max_delay
# logger.warning(
# f"Rate limit hit. Retrying after {retry_after} seconds..."
# )
# try:
# r.set(
# NEXT_RETRY_KEY,
# math.ceil(time.monotonic() + retry_after),
# )
# except ConnectionError:
# pass
# else:
# logger.warning(
# "Rate limit hit. Retrying with exponential backoff..."
# )
# delay = min(starting_delay * (backoff**attempt), max_delay)
# delay_until = math.ceil(time.monotonic() + delay)
# try:
# r.set(NEXT_RETRY_KEY, delay_until)
# except ConnectionError:
# while time.monotonic() < delay_until:
# time.sleep(1)
# else:
# # re-raise, let caller handle
# raise
# except AttributeError as e:
# # Some error within the Confluence library, unclear why it fails.
# # Users reported it to be intermittent, so just retry
# logger.warning(f"Confluence Internal Error, retrying... {e}")
# delay = min(starting_delay * (backoff**attempt), max_delay)
# delay_until = math.ceil(time.monotonic() + delay)
# try:
# r.set(NEXT_RETRY_KEY, delay_until)
# except ConnectionError:
# while time.monotonic() < delay_until:
# time.sleep(1)
# if attempt == max_retries - 1:
# raise e
# return cast(F, wrapped_call)

View File

@ -18,7 +18,25 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
_USER_NOT_FOUND = "Unknown User"
# Maps a username to the email found for it. A stored None means the lookup
# completed but the response carried no "email" entry.
_USER_EMAIL_CACHE: dict[str, str | None] = {}


def get_user_email_from_username__server(
    confluence_client: OnyxConfluence, user_name: str
) -> str | None:
    """
    Look up the email for `user_name`, caching the answer per username.

    Args:
        confluence_client: Client whose `get_mobile_parameters(user_name)`
            response is read for its "email" entry.
        user_name: Username to look up.

    Returns:
        The email from the response, or None when the response has no email
        or the lookup raised.
    """
    # Membership test (not `.get(...) is None`) so that a cached None is
    # honored instead of triggering the same lookup again on every call.
    if user_name in _USER_EMAIL_CACHE:
        return _USER_EMAIL_CACHE[user_name]

    try:
        response = confluence_client.get_mobile_parameters(user_name)
        email = response.get("email")
    except Exception:
        # Best-effort lookup: report "no email" but do not cache the failure,
        # so a later call can try again.
        return None

    _USER_EMAIL_CACHE[user_name] = email
    return email
_USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str] = {}
@ -32,19 +50,22 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
Returns:
str: The User Display Name. 'Unknown User' if the user is deactivated or not found
"""
# Cache hit
if user_id in _USER_ID_TO_DISPLAY_NAME_CACHE:
return _USER_ID_TO_DISPLAY_NAME_CACHE[user_id]
global _USER_ID_TO_DISPLAY_NAME_CACHE
if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
try:
result = confluence_client.get_user_details_by_userkey(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
try:
result = confluence_client.get_user_details_by_accountid(user_id)
if found_display_name := result.get("displayName"):
_USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name
except Exception:
# may need to just not log this error but will leave here for now
logger.exception(
f"Unable to get the User Display Name with the id: '{user_id}'"
)
if not found_display_name:
try:
result = confluence_client.get_user_details_by_accountid(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
_USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id, _USER_NOT_FOUND)
@ -174,3 +195,20 @@ def datetime_from_string(datetime_string: str) -> datetime:
datetime_object = datetime_object.astimezone(timezone.utc)
return datetime_object
def build_confluence_client(
    credentials_json: dict[str, Any], is_cloud: bool, wiki_base: str
) -> OnyxConfluence:
    """
    Construct an `OnyxConfluence` client from credentials and connection info.

    Args:
        credentials_json: Mapping read for "confluence_access_token" and,
            when `is_cloud` is true, "confluence_username".
        is_cloud: Selects api_version "cloud" with username/password auth
            when true, api_version "latest" with token auth otherwise.
        wiki_base: Base URL; trailing slashes are stripped.

    Returns:
        The configured client, with backoff-and-retry enabled.
    """
    api_version = "cloud" if is_cloud else "latest"
    # Remove trailing slash from wiki_base if present
    base_url = wiki_base.rstrip("/")

    username = None
    password = None
    token = None
    if is_cloud:
        username = credentials_json["confluence_username"]
        password = credentials_json["confluence_access_token"]
    else:
        # passing in username causes issues for Confluence data center
        token = credentials_json["confluence_access_token"]

    return OnyxConfluence(
        api_version=api_version,
        url=base_url,
        username=username,
        password=password,
        token=token,
        backoff_and_retry=True,
        max_backoff_retries=60,
        max_backoff_seconds=60,
    )

View File

@ -8,15 +8,13 @@ from sqlalchemy.orm import Session
from danswer.access.models import ExternalAccess
from danswer.connectors.confluence.connector import ConfluenceConnector
from danswer.connectors.confluence.connector import OnyxConfluence
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
from danswer.connectors.confluence.utils import get_user_email_from_username__server
from danswer.connectors.models import SlimDocument
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.document import upsert_document_external_perms__no_commit
from ee.danswer.external_permissions.confluence.sync_utils import (
get_user_email_from_username__server,
)
logger = setup_logger()
@ -244,11 +242,10 @@ def confluence_doc_sync(
confluence_client=confluence_client,
is_cloud=is_cloud,
)
slim_docs = [
slim_doc
for doc_batch in confluence_connector.retrieve_all_slim_documents()
for slim_doc in doc_batch
]
slim_docs = []
for doc_batch in confluence_connector.retrieve_all_slim_documents():
slim_docs.extend(doc_batch)
permissions_by_doc_id = _fetch_all_page_restrictions_for_space(
confluence_client=confluence_client,

View File

@ -1,91 +1,40 @@
from collections.abc import Iterator
from typing import Any
from atlassian import Confluence # type:ignore
from requests import HTTPError
from sqlalchemy.orm import Session
from danswer.connectors.confluence.onyx_confluence import (
handle_confluence_rate_limit,
)
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
from danswer.connectors.confluence.utils import build_confluence_client
from danswer.connectors.confluence.utils import get_user_email_from_username__server
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.confluence.sync_utils import (
build_confluence_client,
)
from ee.danswer.external_permissions.confluence.sync_utils import (
get_user_email_from_username__server,
)
logger = setup_logger()
_PAGE_SIZE = 100
def _get_confluence_group_names_paginated(
    confluence_client: Confluence,
) -> Iterator[str]:
    """
    Yield the name of every group returned by the client's `get_all_groups`,
    fetched `_PAGE_SIZE` groups at a time.

    Groups without a truthy "name" are skipped. Iteration ends silently when
    a request fails with HTTP 403 or 404; any other `HTTPError` is re-raised.
    """
    fetch_groups = handle_confluence_rate_limit(confluence_client.get_all_groups)

    offset = 0
    more_pages = True
    while more_pages:
        try:
            page = fetch_groups(start=offset, limit=_PAGE_SIZE)
        except HTTPError as http_error:
            if http_error.response.status_code not in (403, 404):
                raise http_error
            return

        for group in page:
            name = group.get("name")
            if name:
                yield name

        # A short page means there is nothing further to fetch.
        more_pages = len(page) >= _PAGE_SIZE
        offset += _PAGE_SIZE
def _get_group_members_email_paginated(
confluence_client: Confluence,
confluence_client: OnyxConfluence,
group_name: str,
is_cloud: bool,
) -> set[str]:
get_group_members = handle_confluence_rate_limit(
confluence_client.get_group_members
)
members: list[dict[str, Any]] = []
for member_batch in confluence_client.paginated_group_members_retrieval(group_name):
members.extend(member_batch)
group_member_emails: set[str] = set()
start = 0
while True:
try:
members = get_group_members(
group_name=group_name, start=start, limit=_PAGE_SIZE
)
except HTTPError as e:
if e.response.status_code == 403 or e.response.status_code == 404:
return group_member_emails
raise e
for member in members:
if is_cloud:
email = member.get("email")
elif user_name := member.get("username"):
for member in members:
email = member.get("email")
if not email:
user_name = member.get("username")
if user_name:
email = get_user_email_from_username__server(
confluence_client, user_name
confluence_client=confluence_client,
user_name=user_name,
)
else:
logger.warning(f"Member has no email or username: {member}")
email = None
if email:
group_member_emails.add(email)
if len(members) < _PAGE_SIZE:
break
start += _PAGE_SIZE
if email:
group_member_emails.add(email)
return group_member_emails
@ -94,17 +43,25 @@ def confluence_group_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> None:
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
confluence_client = build_confluence_client(
connector_specific_config=cc_pair.connector.connector_specific_config,
credentials_json=cc_pair.credential.credential_json,
is_cloud=is_cloud,
wiki_base=cc_pair.connector.connector_specific_config["wiki_base"],
)
# Get all group names
group_names: list[str] = []
for group_batch in confluence_client.paginated_groups_retrieval():
for group in group_batch:
if group_name := group.get("name"):
group_names.append(group_name)
# For each group name, get all members and create a danswer group
danswer_groups: list[ExternalUserGroup] = []
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
# Confluence enforces that group names are unique
for group_name in _get_confluence_group_names_paginated(confluence_client):
for group_name in group_names:
group_member_emails = _get_group_members_email_paginated(
confluence_client, group_name, is_cloud
confluence_client, group_name
)
group_members = batch_add_non_web_user_if_not_exists__no_commit(
db_session=db_session, emails=list(group_member_emails)

View File

@ -1,43 +0,0 @@
from typing import Any
from danswer.connectors.confluence.connector import OnyxConfluence
from danswer.connectors.confluence.onyx_confluence import (
handle_confluence_rate_limit,
)
_USER_EMAIL_CACHE: dict[str, str | None] = {}
def build_confluence_client(
    connector_specific_config: dict[str, Any], credentials_json: dict[str, Any]
) -> OnyxConfluence:
    """
    Construct an `OnyxConfluence` client from connector config and credentials.

    Args:
        connector_specific_config: Mapping read for "wiki_base" (required)
            and "is_cloud" (defaults to False).
        credentials_json: Mapping read for "confluence_access_token" and,
            for cloud, "confluence_username".

    Returns:
        The configured client, with backoff-and-retry enabled.
    """
    is_cloud = connector_specific_config.get("is_cloud", False)
    api_version = "cloud" if is_cloud else "latest"
    # Remove trailing slash from wiki_base if present
    base_url = connector_specific_config["wiki_base"].rstrip("/")

    username = None
    password = None
    token = None
    if is_cloud:
        username = credentials_json["confluence_username"]
        password = credentials_json["confluence_access_token"]
    else:
        # passing in username causes issues for Confluence data center
        token = credentials_json["confluence_access_token"]

    return OnyxConfluence(
        api_version=api_version,
        url=base_url,
        username=username,
        password=password,
        token=token,
        backoff_and_retry=True,
        max_backoff_retries=60,
        max_backoff_seconds=60,
    )
def get_user_email_from_username__server(
    confluence_client: OnyxConfluence, user_name: str
) -> str | None:
    """
    Look up the email for `user_name`, caching the answer in
    `_USER_EMAIL_CACHE`.

    Args:
        confluence_client: Client whose `get_mobile_parameters` is called
            (wrapped with `handle_confluence_rate_limit`) and whose response
            is read for its "email" entry.
        user_name: Username to look up.

    Returns:
        The email from the response, or None when the response has no email
        or the lookup raised.
    """
    # Membership test (not `.get(...) is None`) so that a cached None is
    # honored instead of triggering the same lookup again on every call.
    if user_name in _USER_EMAIL_CACHE:
        return _USER_EMAIL_CACHE[user_name]

    # Only build the rate-limited wrapper when a request is actually needed.
    get_user_info = handle_confluence_rate_limit(
        confluence_client.get_mobile_parameters
    )
    try:
        response = get_user_info(user_name)
        email = response.get("email")
    except Exception:
        # Best-effort lookup: report "no email" but do not cache the failure,
        # so a later call can try again.
        return None

    _USER_EMAIL_CACHE[user_name] = email
    return email