Add support for overriding user list (#4616)

* Add support for overriding user list

* Fix

* Add typing

* pythonify
This commit is contained in:
Chris Weaver
2025-04-25 15:15:23 -07:00
committed by GitHub
parent 23c6e0f3bf
commit 92b5e1adf4
5 changed files with 200 additions and 13 deletions

View File

@@ -351,6 +351,11 @@ NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP = (
== "true"
)
#####
# Confluence Connector Configs
#####
CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
ignored_tag
for ignored_tag in os.environ.get("CONFLUENCE_CONNECTOR_LABELS_TO_SKIP", "").split(
@@ -374,6 +379,26 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
)
# A JSON-formatted array. Each item in the array should have the following structure:
# {
# "user_id": "1234567890",
# "username": "bob",
# "display_name": "Bob Fitzgerald",
# "email": "bob@example.com",
# "type": "known"
# }
_RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE = os.environ.get(
"CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE", ""
)
CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE = cast(
list[dict[str, str]] | None,
(
json.loads(_RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE)
if _RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE
else None
),
)
# Due to breakages in the confluence API, the timezone offset must be specified client side
# to match the user's specified timezone.

View File

@@ -0,0 +1,11 @@
from pydantic import BaseModel
class ConfluenceUser(BaseModel):
user_id: str # accountId in Cloud, userKey in Server
username: str | None # Confluence Cloud doesn't give usernames
display_name: str
# Confluence Data Center doesn't give email back by default,
# have to fetch it with a different endpoint
email: str | None
type: str

View File

@@ -12,12 +12,16 @@ from urllib.parse import quote
import bs4
from atlassian import Confluence # type:ignore
from pydantic import BaseModel
from redis import Redis
from requests import HTTPError
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE
from onyx.connectors.confluence.models import ConfluenceUser
from onyx.connectors.confluence.user_profile_override import (
process_confluence_user_profiles_override,
)
from onyx.connectors.confluence.utils import _handle_http_error
from onyx.connectors.confluence.utils import confluence_refresh_tokens
from onyx.connectors.confluence.utils import get_start_param_from_url
@@ -46,16 +50,6 @@ class ConfluenceRateLimitError(Exception):
pass
class ConfluenceUser(BaseModel):
user_id: str # accountId in Cloud, userKey in Server
username: str | None # Confluence Cloud doesn't give usernames
display_name: str
# Confluence Data Center doesn't give email back by default,
# have to fetch it with a different endpoint
email: str | None
type: str
_DEFAULT_PAGINATION_LIMIT = 1000
_MINIMUM_PAGINATION_LIMIT = 50
@@ -80,6 +74,11 @@ class OnyxConfluence:
url: str,
credentials_provider: CredentialsProviderInterface,
timeout: int | None = None,
# should generally not be passed in, but making it overridable for
# easier testing
confluence_user_profiles_override: list[dict[str, str]] | None = (
CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE
),
) -> None:
self._is_cloud = is_cloud
self._url = url.rstrip("/")
@@ -110,6 +109,12 @@ class OnyxConfluence:
if timeout:
self.shared_base_kwargs["timeout"] = timeout
self._confluence_user_profiles_override = (
process_confluence_user_profiles_override(confluence_user_profiles_override)
if confluence_user_profiles_override
else None
)
def _renew_credentials(self) -> tuple[dict[str, Any], bool]:
"""credential_json - the current json credentials
Returns a tuple
@@ -589,7 +594,14 @@ class OnyxConfluence:
It's a seperate endpoint from the content/search endpoint used only for users.
Otherwise it's very similar to the content/search endpoint.
"""
if self.cloud:
# this is needed since there is a live bug with Confluence Server/Data Center
# where not all users are returned by the APIs. This is a workaround needed until
# that is patched.
if self._confluence_user_profiles_override:
yield from self._confluence_user_profiles_override
elif self._is_cloud:
cql = "type=user"
url = "rest/api/search/user"
expand_string = f"&expand={expand}" if expand else ""
@@ -680,7 +692,7 @@ class OnyxConfluence:
This is not an SQL like query.
It's a confluence specific endpoint that can be used to fetch groups.
"""
user_field = "accountId" if self.cloud else "key"
user_field = "accountId" if self._is_cloud else "key"
user_value = user_id
# Server uses userKey (but calls it key during the API call), Cloud uses accountId
user_query = f"{user_field}={quote(user_value)}"

View File

@@ -0,0 +1,18 @@
from onyx.connectors.confluence.models import ConfluenceUser
def process_confluence_user_profiles_override(
confluence_user_email_override: list[dict[str, str]],
) -> list[ConfluenceUser]:
return [
ConfluenceUser(
user_id=override["user_id"],
# username is not returned by the Confluence Server API anyways
username=override["username"],
display_name=override["display_name"],
email=override["email"],
type=override["type"],
)
for override in confluence_user_email_override
if override is not None
]

View File

@@ -0,0 +1,121 @@
import types
from unittest.mock import patch
from onyx.connectors.confluence.onyx_confluence import ConfluenceUser
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.interfaces import CredentialsProviderInterface
class MockCredentialsProvider(CredentialsProviderInterface):
def get_tenant_id(self) -> str:
return "test_tenant"
def get_provider_key(self) -> str:
return "test_provider"
def is_dynamic(self) -> bool:
return False
def get_credentials(self) -> dict[str, str]:
return {"confluence_access_token": "test_token"}
def set_credentials(self, credentials: dict[str, str]) -> None:
pass
def __enter__(self) -> "MockCredentialsProvider":
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: types.TracebackType | None,
) -> None:
pass
def test_paginated_cql_user_retrieval_with_overrides() -> None:
"""
Tests that paginated_cql_user_retrieval yields users from the overrides
when provided and is_cloud is False.
"""
mock_provider = MockCredentialsProvider()
overrides = [
{
"user_id": "override_user_1",
"username": "override1",
"display_name": "Override User One",
"email": "override1@example.com",
"type": "override",
},
{
"user_id": "override_user_2",
"username": "override2",
"display_name": "Override User Two",
"email": "override2@example.com",
"type": "override",
},
]
expected_users = [ConfluenceUser(**user_data) for user_data in overrides]
confluence_client = OnyxConfluence(
is_cloud=False, # Overrides are primarily for Server/DC
url="http://dummy-confluence.com",
credentials_provider=mock_provider,
confluence_user_profiles_override=overrides,
)
retrieved_users = list(confluence_client.paginated_cql_user_retrieval())
assert len(retrieved_users) == len(expected_users)
# Sort lists by user_id for order-independent comparison
retrieved_users.sort(key=lambda u: u.user_id)
expected_users.sort(key=lambda u: u.user_id)
assert retrieved_users == expected_users
def test_paginated_cql_user_retrieval_no_overrides_server() -> None:
"""
Tests that paginated_cql_user_retrieval attempts to call the actual
API pagination when no overrides are provided for Server/DC.
"""
mock_provider = MockCredentialsProvider()
confluence_client = OnyxConfluence(
is_cloud=False,
url="http://dummy-confluence.com",
credentials_provider=mock_provider,
confluence_user_profiles_override=None,
)
# Mock the internal pagination method to check if it's called
with patch.object(confluence_client, "_paginate_url") as mock_paginate:
mock_paginate.return_value = iter([]) # Return an empty iterator
list(confluence_client.paginated_cql_user_retrieval())
mock_paginate.assert_called_once_with("rest/api/user/list", None)
def test_paginated_cql_user_retrieval_no_overrides_cloud() -> None:
"""
Tests that paginated_cql_user_retrieval attempts to call the actual
API pagination when no overrides are provided for Cloud.
"""
mock_provider = MockCredentialsProvider()
confluence_client = OnyxConfluence(
is_cloud=True,
url="http://dummy-confluence.com", # URL doesn't matter much here due to mocking
credentials_provider=mock_provider,
confluence_user_profiles_override=None,
)
# Mock the internal pagination method to check if it's called
with patch.object(confluence_client, "_paginate_url") as mock_paginate:
mock_paginate.return_value = iter([]) # Return an empty iterator
list(confluence_client.paginated_cql_user_retrieval())
# Check that the cloud-specific user search URL is called
mock_paginate.assert_called_once_with(
"rest/api/search/user?cql=type=user", None, auto_paginate=True
)