Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-09-19 12:03:54 +02:00)
Fix confluence permission syncing at scale (#4129)
* Fix confluence permission syncing at scale
* Remove line
* Better log message
* Adjust log
@@ -11,6 +11,8 @@ from atlassian import Confluence # type:ignore
 from pydantic import BaseModel
 from requests import HTTPError
 
+from onyx.connectors.confluence.utils import get_start_param_from_url
+from onyx.connectors.confluence.utils import update_param_in_path
 from onyx.connectors.exceptions import ConnectorValidationError
 from onyx.utils.logger import setup_logger
 
@@ -161,7 +163,7 @@ class OnyxConfluence(Confluence):
         )
 
     def _paginate_url(
-        self, url_suffix: str, limit: int | None = None
+        self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False
     ) -> Iterator[dict[str, Any]]:
         """
         This will paginate through the top level query.
@@ -236,9 +238,41 @@ class OnyxConfluence(Confluence):
                 raise e
 
             # yield the results individually
-            yield from next_response.get("results", [])
+            results = cast(list[dict[str, Any]], next_response.get("results", []))
+            yield from results
 
-            url_suffix = next_response.get("_links", {}).get("next")
+            old_url_suffix = url_suffix
+            url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
+
+            # make sure we don't update the start by more than the amount
+            # of results we were able to retrieve. The Confluence API has a
+            # weird behavior where if you pass in a limit that is too large for
+            # the configured server, it will artificially limit the amount of
+            # results returned BUT will not apply this to the start parameter.
+            # This will cause us to miss results.
+            if url_suffix and "start" in url_suffix:
+                new_start = get_start_param_from_url(url_suffix)
+                previous_start = get_start_param_from_url(old_url_suffix)
+                if new_start - previous_start > len(results):
+                    logger.warning(
+                        f"Start was updated by more than the amount of results "
+                        f"retrieved. This is a bug with Confluence. Start: {new_start}, "
+                        f"Previous Start: {previous_start}, Len Results: {len(results)}."
+                    )
+
+                    # Update the url_suffix to use the adjusted start
+                    adjusted_start = previous_start + len(results)
+                    url_suffix = update_param_in_path(
+                        url_suffix, "start", str(adjusted_start)
+                    )
+
+            # some APIs don't properly paginate, so we need to manually update the `start` param
+            if auto_paginate and len(results) > 0:
+                previous_start = get_start_param_from_url(old_url_suffix)
+                updated_start = previous_start + len(results)
+                url_suffix = update_param_in_path(
+                    old_url_suffix, "start", str(updated_start)
+                )
 
     def paginated_cql_retrieval(
         self,
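
The comment block in this hunk describes the bug the commit targets: when the requested limit exceeds what the Confluence server is configured to serve, the server silently caps the number of results it returns but still advances the start offset in the _links.next URL by the full requested limit, so a naive pager skips records. The sketch below is a standalone illustration of the correction arithmetic, not the connector code itself; the URLs, numbers, and helper names are made up for the example.

# Standalone sketch of the start-correction logic; illustrative values only.
from urllib.parse import parse_qs, urlparse


def _start_of(url: str) -> int:
    # same idea as get_start_param_from_url: a missing `start` means 0
    return int(parse_qs(urlparse(url).query).get("start", ["0"])[0])


def corrected_next_url(old_url: str, next_url: str, results_returned: int) -> str:
    # If the server advanced `start` further than the number of rows it
    # actually returned, rewind `start` to previous_start + results_returned.
    new_start = _start_of(next_url)
    previous_start = _start_of(old_url)
    if new_start - previous_start > results_returned:
        base, query = next_url.split("?", 1)
        params = parse_qs(query)
        params["start"] = [str(previous_start + results_returned)]
        return base + "?" + "&".join(f"{k}={v[0]}" for k, v in params.items())
    return next_url


# We asked for limit=100; the server returned only 25 rows but advanced
# start by 100 in the next link, so 75 records would be lost without the fix.
print(corrected_next_url(
    "/rest/api/search/user?cql=type=user&start=0&limit=100",
    "/rest/api/search/user?cql=type=user&start=100&limit=100",
    results_returned=25,
))
# -> /rest/api/search/user?cql=type=user&start=25&limit=100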
@@ -298,7 +332,9 @@ class OnyxConfluence(Confluence):
         url = "rest/api/search/user"
         expand_string = f"&expand={expand}" if expand else ""
         url += f"?cql={cql}{expand_string}"
-        for user_result in self._paginate_url(url, limit):
+        # endpoint doesn't properly paginate, so we need to manually update the `start` param
+        # thus the auto_paginate flag
+        for user_result in self._paginate_url(url, limit, auto_paginate=True):
             # Example response:
             # {
             #     'user': {
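
The rest/api/search/user endpoint above is the motivating case for the new auto_paginate flag: as the added comment notes, it does not paginate properly, so _paginate_url has to advance start by the number of results it actually received. A rough usage sketch, assuming an already-configured client; the wrapper function, CQL string, and limit below are illustrative and not part of this diff:

def iter_all_users(confluence_client: "OnyxConfluence") -> list[dict]:
    # auto_paginate=True makes _paginate_url advance `start` itself instead of
    # trusting the endpoint's pagination links (which, per the diff comment,
    # don't work properly for this endpoint).
    url = "rest/api/search/user?cql=type=user"
    return list(confluence_client._paginate_url(url, limit=50, auto_paginate=True))

The remaining hunks are in the connector's utils module (onyx.connectors.confluence.utils), which gains the URL helpers used above and drops its runtime import of OnyxConfluence.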
@@ -2,7 +2,10 @@ import io
 from datetime import datetime
 from datetime import timezone
 from typing import Any
+from typing import TYPE_CHECKING
+from urllib.parse import parse_qs
 from urllib.parse import quote
+from urllib.parse import urlparse
 
 import bs4
 
@@ -10,13 +13,13 @@ from onyx.configs.app_configs import (
     CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
 )
 from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
-from onyx.connectors.confluence.onyx_confluence import (
-    OnyxConfluence,
-)
 from onyx.file_processing.extract_file_text import extract_file_text
 from onyx.file_processing.html_utils import format_document_soup
 from onyx.utils.logger import setup_logger
 
+if TYPE_CHECKING:
+    from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
+
 logger = setup_logger()
 
 
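
Dropping the module-level OnyxConfluence import in favor of a TYPE_CHECKING-only import (and quoting the annotations in the hunks below) keeps the type information for checkers without importing onyx_confluence at runtime; this is presumably what avoids the circular import that the new utils imports added to onyx_confluence would otherwise create. A minimal sketch of the pattern, with an illustrative function name:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # evaluated by type checkers only, never at runtime
    from onyx.connectors.confluence.onyx_confluence import OnyxConfluence


def needs_client(confluence_client: "OnyxConfluence") -> None:
    # the quoted annotation defers evaluation, so no runtime import is needed
    ...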
@@ -24,7 +27,7 @@ _USER_EMAIL_CACHE: dict[str, str | None] = {}
 
 
 def get_user_email_from_username__server(
-    confluence_client: OnyxConfluence, user_name: str
+    confluence_client: "OnyxConfluence", user_name: str
 ) -> str | None:
     global _USER_EMAIL_CACHE
     if _USER_EMAIL_CACHE.get(user_name) is None:
@@ -47,7 +50,7 @@ _USER_NOT_FOUND = "Unknown Confluence User"
 _USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
 
 
-def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
+def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
     """Get Confluence Display Name based on the account-id or userkey value
 
     Args:
@@ -78,7 +81,7 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
 
 
 def extract_text_from_confluence_html(
-    confluence_client: OnyxConfluence,
+    confluence_client: "OnyxConfluence",
     confluence_object: dict[str, Any],
     fetched_titles: set[str],
 ) -> str:
@@ -191,7 +194,7 @@ def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
 
 
 def attachment_to_content(
-    confluence_client: OnyxConfluence,
+    confluence_client: "OnyxConfluence",
     attachment: dict[str, Any],
 ) -> str | None:
     """If it returns None, assume that we should skip this attachment."""
@@ -279,3 +282,32 @@ def datetime_from_string(datetime_string: str) -> datetime:
     datetime_object = datetime_object.astimezone(timezone.utc)
 
     return datetime_object
+
+
+def get_single_param_from_url(url: str, param: str) -> str | None:
+    """Get a parameter from a url"""
+    parsed_url = urlparse(url)
+    return parse_qs(parsed_url.query).get(param, [None])[0]
+
+
+def get_start_param_from_url(url: str) -> int:
+    """Get the start parameter from a url"""
+    start_str = get_single_param_from_url(url, "start")
+    if start_str is None:
+        return 0
+    return int(start_str)
+
+
+def update_param_in_path(path: str, param: str, value: str) -> str:
+    """Update a parameter in a path. Path should look something like:
+
+    /api/rest/users?start=0&limit=10
+    """
+    parsed_url = urlparse(path)
+    query_params = parse_qs(parsed_url.query)
+    query_params[param] = [value]
+    return (
+        path.split("?")[0]
+        + "?"
+        + "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
+    )
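
A quick illustration of how the three new helpers behave; the paths and numbers are made up, and the snippet assumes it runs inside the repo so the module import resolves:

# Hypothetical example values; imports the helpers added in the hunk above.
from onyx.connectors.confluence.utils import (
    get_single_param_from_url,
    get_start_param_from_url,
    update_param_in_path,
)

url = "/api/rest/users?start=100&limit=10"

assert get_single_param_from_url(url, "limit") == "10"
assert get_start_param_from_url(url) == 100
# a missing `start` parameter falls back to 0
assert get_start_param_from_url("/api/rest/users?limit=10") == 0
# rewind `start` to previous_start + len(results), e.g. 0 + 25
assert update_param_in_path(url, "start", "25") == "/api/rest/users?start=25&limit=10"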
|