mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-08-30 15:46:19 +02:00
Add support for overriding user list (#4616)
* Add support for overriding user list * Fix * Add typing * pythonify
This commit is contained in:
@@ -351,6 +351,11 @@ NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP = (
|
||||
== "true"
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
# Confluence Connector Configs
|
||||
#####
|
||||
|
||||
CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
for ignored_tag in os.environ.get("CONFLUENCE_CONNECTOR_LABELS_TO_SKIP", "").split(
|
||||
@@ -374,6 +379,26 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
|
||||
)
|
||||
|
||||
# A JSON-formatted array. Each item in the array should have the following structure:
|
||||
# {
|
||||
# "user_id": "1234567890",
|
||||
# "username": "bob",
|
||||
# "display_name": "Bob Fitzgerald",
|
||||
# "email": "bob@example.com",
|
||||
# "type": "known"
|
||||
# }
|
||||
_RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE = os.environ.get(
|
||||
"CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE", ""
|
||||
)
|
||||
CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE = cast(
|
||||
list[dict[str, str]] | None,
|
||||
(
|
||||
json.loads(_RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE)
|
||||
if _RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
# Due to breakages in the confluence API, the timezone offset must be specified client side
|
||||
# to match the user's specified timezone.
|
||||
|
||||
|
11
backend/onyx/connectors/confluence/models.py
Normal file
11
backend/onyx/connectors/confluence/models.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ConfluenceUser(BaseModel):
|
||||
user_id: str # accountId in Cloud, userKey in Server
|
||||
username: str | None # Confluence Cloud doesn't give usernames
|
||||
display_name: str
|
||||
# Confluence Data Center doesn't give email back by default,
|
||||
# have to fetch it with a different endpoint
|
||||
email: str | None
|
||||
type: str
|
@@ -12,12 +12,16 @@ from urllib.parse import quote
|
||||
|
||||
import bs4
|
||||
from atlassian import Confluence # type:ignore
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
from requests import HTTPError
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE
|
||||
from onyx.connectors.confluence.models import ConfluenceUser
|
||||
from onyx.connectors.confluence.user_profile_override import (
|
||||
process_confluence_user_profiles_override,
|
||||
)
|
||||
from onyx.connectors.confluence.utils import _handle_http_error
|
||||
from onyx.connectors.confluence.utils import confluence_refresh_tokens
|
||||
from onyx.connectors.confluence.utils import get_start_param_from_url
|
||||
@@ -46,16 +50,6 @@ class ConfluenceRateLimitError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ConfluenceUser(BaseModel):
|
||||
user_id: str # accountId in Cloud, userKey in Server
|
||||
username: str | None # Confluence Cloud doesn't give usernames
|
||||
display_name: str
|
||||
# Confluence Data Center doesn't give email back by default,
|
||||
# have to fetch it with a different endpoint
|
||||
email: str | None
|
||||
type: str
|
||||
|
||||
|
||||
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||
_MINIMUM_PAGINATION_LIMIT = 50
|
||||
|
||||
@@ -80,6 +74,11 @@ class OnyxConfluence:
|
||||
url: str,
|
||||
credentials_provider: CredentialsProviderInterface,
|
||||
timeout: int | None = None,
|
||||
# should generally not be passed in, but making it overridable for
|
||||
# easier testing
|
||||
confluence_user_profiles_override: list[dict[str, str]] | None = (
|
||||
CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE
|
||||
),
|
||||
) -> None:
|
||||
self._is_cloud = is_cloud
|
||||
self._url = url.rstrip("/")
|
||||
@@ -110,6 +109,12 @@ class OnyxConfluence:
|
||||
if timeout:
|
||||
self.shared_base_kwargs["timeout"] = timeout
|
||||
|
||||
self._confluence_user_profiles_override = (
|
||||
process_confluence_user_profiles_override(confluence_user_profiles_override)
|
||||
if confluence_user_profiles_override
|
||||
else None
|
||||
)
|
||||
|
||||
def _renew_credentials(self) -> tuple[dict[str, Any], bool]:
|
||||
"""credential_json - the current json credentials
|
||||
Returns a tuple
|
||||
@@ -589,7 +594,14 @@ class OnyxConfluence:
|
||||
It's a seperate endpoint from the content/search endpoint used only for users.
|
||||
Otherwise it's very similar to the content/search endpoint.
|
||||
"""
|
||||
if self.cloud:
|
||||
|
||||
# this is needed since there is a live bug with Confluence Server/Data Center
|
||||
# where not all users are returned by the APIs. This is a workaround needed until
|
||||
# that is patched.
|
||||
if self._confluence_user_profiles_override:
|
||||
yield from self._confluence_user_profiles_override
|
||||
|
||||
elif self._is_cloud:
|
||||
cql = "type=user"
|
||||
url = "rest/api/search/user"
|
||||
expand_string = f"&expand={expand}" if expand else ""
|
||||
@@ -680,7 +692,7 @@ class OnyxConfluence:
|
||||
This is not an SQL like query.
|
||||
It's a confluence specific endpoint that can be used to fetch groups.
|
||||
"""
|
||||
user_field = "accountId" if self.cloud else "key"
|
||||
user_field = "accountId" if self._is_cloud else "key"
|
||||
user_value = user_id
|
||||
# Server uses userKey (but calls it key during the API call), Cloud uses accountId
|
||||
user_query = f"{user_field}={quote(user_value)}"
|
||||
|
18
backend/onyx/connectors/confluence/user_profile_override.py
Normal file
18
backend/onyx/connectors/confluence/user_profile_override.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from onyx.connectors.confluence.models import ConfluenceUser
|
||||
|
||||
|
||||
def process_confluence_user_profiles_override(
|
||||
confluence_user_email_override: list[dict[str, str]],
|
||||
) -> list[ConfluenceUser]:
|
||||
return [
|
||||
ConfluenceUser(
|
||||
user_id=override["user_id"],
|
||||
# username is not returned by the Confluence Server API anyways
|
||||
username=override["username"],
|
||||
display_name=override["display_name"],
|
||||
email=override["email"],
|
||||
type=override["type"],
|
||||
)
|
||||
for override in confluence_user_email_override
|
||||
if override is not None
|
||||
]
|
@@ -0,0 +1,121 @@
|
||||
import types
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.confluence.onyx_confluence import ConfluenceUser
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
|
||||
|
||||
class MockCredentialsProvider(CredentialsProviderInterface):
|
||||
def get_tenant_id(self) -> str:
|
||||
return "test_tenant"
|
||||
|
||||
def get_provider_key(self) -> str:
|
||||
return "test_provider"
|
||||
|
||||
def is_dynamic(self) -> bool:
|
||||
return False
|
||||
|
||||
def get_credentials(self) -> dict[str, str]:
|
||||
return {"confluence_access_token": "test_token"}
|
||||
|
||||
def set_credentials(self, credentials: dict[str, str]) -> None:
|
||||
pass
|
||||
|
||||
def __enter__(self) -> "MockCredentialsProvider":
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: types.TracebackType | None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def test_paginated_cql_user_retrieval_with_overrides() -> None:
|
||||
"""
|
||||
Tests that paginated_cql_user_retrieval yields users from the overrides
|
||||
when provided and is_cloud is False.
|
||||
"""
|
||||
mock_provider = MockCredentialsProvider()
|
||||
overrides = [
|
||||
{
|
||||
"user_id": "override_user_1",
|
||||
"username": "override1",
|
||||
"display_name": "Override User One",
|
||||
"email": "override1@example.com",
|
||||
"type": "override",
|
||||
},
|
||||
{
|
||||
"user_id": "override_user_2",
|
||||
"username": "override2",
|
||||
"display_name": "Override User Two",
|
||||
"email": "override2@example.com",
|
||||
"type": "override",
|
||||
},
|
||||
]
|
||||
expected_users = [ConfluenceUser(**user_data) for user_data in overrides]
|
||||
|
||||
confluence_client = OnyxConfluence(
|
||||
is_cloud=False, # Overrides are primarily for Server/DC
|
||||
url="http://dummy-confluence.com",
|
||||
credentials_provider=mock_provider,
|
||||
confluence_user_profiles_override=overrides,
|
||||
)
|
||||
|
||||
retrieved_users = list(confluence_client.paginated_cql_user_retrieval())
|
||||
|
||||
assert len(retrieved_users) == len(expected_users)
|
||||
# Sort lists by user_id for order-independent comparison
|
||||
retrieved_users.sort(key=lambda u: u.user_id)
|
||||
expected_users.sort(key=lambda u: u.user_id)
|
||||
assert retrieved_users == expected_users
|
||||
|
||||
|
||||
def test_paginated_cql_user_retrieval_no_overrides_server() -> None:
|
||||
"""
|
||||
Tests that paginated_cql_user_retrieval attempts to call the actual
|
||||
API pagination when no overrides are provided for Server/DC.
|
||||
"""
|
||||
mock_provider = MockCredentialsProvider()
|
||||
confluence_client = OnyxConfluence(
|
||||
is_cloud=False,
|
||||
url="http://dummy-confluence.com",
|
||||
credentials_provider=mock_provider,
|
||||
confluence_user_profiles_override=None,
|
||||
)
|
||||
|
||||
# Mock the internal pagination method to check if it's called
|
||||
with patch.object(confluence_client, "_paginate_url") as mock_paginate:
|
||||
mock_paginate.return_value = iter([]) # Return an empty iterator
|
||||
|
||||
list(confluence_client.paginated_cql_user_retrieval())
|
||||
|
||||
mock_paginate.assert_called_once_with("rest/api/user/list", None)
|
||||
|
||||
|
||||
def test_paginated_cql_user_retrieval_no_overrides_cloud() -> None:
|
||||
"""
|
||||
Tests that paginated_cql_user_retrieval attempts to call the actual
|
||||
API pagination when no overrides are provided for Cloud.
|
||||
"""
|
||||
mock_provider = MockCredentialsProvider()
|
||||
confluence_client = OnyxConfluence(
|
||||
is_cloud=True,
|
||||
url="http://dummy-confluence.com", # URL doesn't matter much here due to mocking
|
||||
credentials_provider=mock_provider,
|
||||
confluence_user_profiles_override=None,
|
||||
)
|
||||
|
||||
# Mock the internal pagination method to check if it's called
|
||||
with patch.object(confluence_client, "_paginate_url") as mock_paginate:
|
||||
mock_paginate.return_value = iter([]) # Return an empty iterator
|
||||
|
||||
list(confluence_client.paginated_cql_user_retrieval())
|
||||
|
||||
# Check that the cloud-specific user search URL is called
|
||||
mock_paginate.assert_called_once_with(
|
||||
"rest/api/search/user?cql=type=user", None, auto_paginate=True
|
||||
)
|
Reference in New Issue
Block a user