diff --git a/backend/onyx/connectors/confluence/onyx_confluence.py b/backend/onyx/connectors/confluence/onyx_confluence.py index 8f36534724a8..df28900bcd1a 100644 --- a/backend/onyx/connectors/confluence/onyx_confluence.py +++ b/backend/onyx/connectors/confluence/onyx_confluence.py @@ -11,6 +11,8 @@ from atlassian import Confluence # type:ignore from pydantic import BaseModel from requests import HTTPError +from onyx.connectors.confluence.utils import get_start_param_from_url +from onyx.connectors.confluence.utils import update_param_in_path from onyx.connectors.exceptions import ConnectorValidationError from onyx.utils.logger import setup_logger @@ -161,7 +163,7 @@ class OnyxConfluence(Confluence): ) def _paginate_url( - self, url_suffix: str, limit: int | None = None + self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False ) -> Iterator[dict[str, Any]]: """ This will paginate through the top level query. @@ -236,9 +238,41 @@ class OnyxConfluence(Confluence): raise e # yield the results individually - yield from next_response.get("results", []) + results = cast(list[dict[str, Any]], next_response.get("results", [])) + yield from results - url_suffix = next_response.get("_links", {}).get("next") + old_url_suffix = url_suffix + url_suffix = cast(str, next_response.get("_links", {}).get("next", "")) + + # make sure we don't update the start by more than the amount + # of results we were able to retrieve. The Confluence API has a + # weird behavior where if you pass in a limit that is too large for + # the configured server, it will artificially limit the amount of + # results returned BUT will not apply this to the start parameter. + # This will cause us to miss results. + if url_suffix and "start" in url_suffix: + new_start = get_start_param_from_url(url_suffix) + previous_start = get_start_param_from_url(old_url_suffix) + if new_start - previous_start > len(results): + logger.warning( + f"Start was updated by more than the amount of results " + f"retrieved. This is a bug with Confluence. Start: {new_start}, " + f"Previous Start: {previous_start}, Len Results: {len(results)}." + ) + + # Update the url_suffix to use the adjusted start + adjusted_start = previous_start + len(results) + url_suffix = update_param_in_path( + url_suffix, "start", str(adjusted_start) + ) + + # some APIs don't properly paginate, so we need to manually update the `start` param + if auto_paginate and len(results) > 0: + previous_start = get_start_param_from_url(old_url_suffix) + updated_start = previous_start + len(results) + url_suffix = update_param_in_path( + old_url_suffix, "start", str(updated_start) + ) def paginated_cql_retrieval( self, @@ -298,7 +332,9 @@ class OnyxConfluence(Confluence): url = "rest/api/search/user" expand_string = f"&expand={expand}" if expand else "" url += f"?cql={cql}{expand_string}" - for user_result in self._paginate_url(url, limit): + # endpoint doesn't properly paginate, so we need to manually update the `start` param + # thus the auto_paginate flag + for user_result in self._paginate_url(url, limit, auto_paginate=True): # Example response: # { # 'user': { diff --git a/backend/onyx/connectors/confluence/utils.py b/backend/onyx/connectors/confluence/utils.py index 49fe60a94c53..b77696645b41 100644 --- a/backend/onyx/connectors/confluence/utils.py +++ b/backend/onyx/connectors/confluence/utils.py @@ -2,7 +2,10 @@ import io from datetime import datetime from datetime import timezone from typing import Any +from typing import TYPE_CHECKING +from urllib.parse import parse_qs from urllib.parse import quote +from urllib.parse import urlparse import bs4 @@ -10,13 +13,13 @@ from onyx.configs.app_configs import ( CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD, ) from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD -from onyx.connectors.confluence.onyx_confluence import ( - OnyxConfluence, -) from onyx.file_processing.extract_file_text import extract_file_text from onyx.file_processing.html_utils import format_document_soup from onyx.utils.logger import setup_logger +if TYPE_CHECKING: + from onyx.connectors.confluence.onyx_confluence import OnyxConfluence + logger = setup_logger() @@ -24,7 +27,7 @@ _USER_EMAIL_CACHE: dict[str, str | None] = {} def get_user_email_from_username__server( - confluence_client: OnyxConfluence, user_name: str + confluence_client: "OnyxConfluence", user_name: str ) -> str | None: global _USER_EMAIL_CACHE if _USER_EMAIL_CACHE.get(user_name) is None: @@ -47,7 +50,7 @@ _USER_NOT_FOUND = "Unknown Confluence User" _USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {} -def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str: +def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str: """Get Confluence Display Name based on the account-id or userkey value Args: @@ -78,7 +81,7 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str: def extract_text_from_confluence_html( - confluence_client: OnyxConfluence, + confluence_client: "OnyxConfluence", confluence_object: dict[str, Any], fetched_titles: set[str], ) -> str: @@ -191,7 +194,7 @@ def validate_attachment_filetype(attachment: dict[str, Any]) -> bool: def attachment_to_content( - confluence_client: OnyxConfluence, + confluence_client: "OnyxConfluence", attachment: dict[str, Any], ) -> str | None: """If it returns None, assume that we should skip this attachment.""" @@ -279,3 +282,32 @@ def datetime_from_string(datetime_string: str) -> datetime: datetime_object = datetime_object.astimezone(timezone.utc) return datetime_object + + +def get_single_param_from_url(url: str, param: str) -> str | None: + """Get a parameter from a url""" + parsed_url = urlparse(url) + return parse_qs(parsed_url.query).get(param, [None])[0] + + +def get_start_param_from_url(url: str) -> int: + """Get the start parameter from a url""" + start_str = get_single_param_from_url(url, "start") + if start_str is None: + return 0 + return int(start_str) + + +def update_param_in_path(path: str, param: str, value: str) -> str: + """Update a parameter in a path. Path should look something like: + + /api/rest/users?start=0&limit=10 + """ + parsed_url = urlparse(path) + query_params = parse_qs(parsed_url.query) + query_params[param] = [value] + return ( + path.split("?")[0] + + "?" + + "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items()) + )