diff --git a/backend/ee/onyx/configs/app_configs.py b/backend/ee/onyx/configs/app_configs.py index b567db38a..3c2b1638c 100644 --- a/backend/ee/onyx/configs/app_configs.py +++ b/backend/ee/onyx/configs/app_configs.py @@ -59,10 +59,14 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key") OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "") OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "") -OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "") -OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "") -OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "") -OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "") +OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get( + "OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", "" +) +OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get( + "OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", "" +) +OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "") +OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "") OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "") OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get( "OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", "" diff --git a/backend/ee/onyx/external_permissions/confluence/doc_sync.py b/backend/ee/onyx/external_permissions/confluence/doc_sync.py index 507a941b8..8ed076a3c 100644 --- a/backend/ee/onyx/external_permissions/confluence/doc_sync.py +++ b/backend/ee/onyx/external_permissions/confluence/doc_sync.py @@ -9,12 +9,16 @@ from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GR from onyx.access.models import DocExternalAccess from onyx.access.models import ExternalAccess from onyx.connectors.confluence.connector import ConfluenceConnector +from onyx.connectors.confluence.onyx_confluence import ( + get_user_email_from_username__server, +) from 
onyx.connectors.confluence.onyx_confluence import OnyxConfluence -from onyx.connectors.confluence.utils import get_user_email_from_username__server +from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider from onyx.connectors.models import SlimDocument from onyx.db.models import ConnectorCredentialPair from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger +from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() @@ -342,7 +346,8 @@ def _fetch_all_page_restrictions( def confluence_doc_sync( - cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None + cc_pair: ConnectorCredentialPair, + callback: IndexingHeartbeatInterface | None, ) -> list[DocExternalAccess]: """ Adds the external permissions to the documents in postgres @@ -354,7 +359,11 @@ def confluence_doc_sync( confluence_connector = ConfluenceConnector( **cc_pair.connector.connector_specific_config ) - confluence_connector.load_credentials(cc_pair.credential.credential_json) + + provider = OnyxDBCredentialsProvider( + get_current_tenant_id(), "confluence", cc_pair.credential_id + ) + confluence_connector.set_credentials_provider(provider) is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False) diff --git a/backend/ee/onyx/external_permissions/confluence/group_sync.py b/backend/ee/onyx/external_permissions/confluence/group_sync.py index b11d38f63..b1113a5ab 100644 --- a/backend/ee/onyx/external_permissions/confluence/group_sync.py +++ b/backend/ee/onyx/external_permissions/confluence/group_sync.py @@ -1,9 +1,11 @@ from ee.onyx.db.external_perm import ExternalUserGroup from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME from onyx.background.error_logging import emit_background_error -from onyx.connectors.confluence.onyx_confluence import build_confluence_client +from onyx.connectors.confluence.onyx_confluence import ( + 
get_user_email_from_username__server, +) from onyx.connectors.confluence.onyx_confluence import OnyxConfluence -from onyx.connectors.confluence.utils import get_user_email_from_username__server +from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider from onyx.db.models import ConnectorCredentialPair from onyx.utils.logger import setup_logger @@ -61,13 +63,27 @@ def _build_group_member_email_map( def confluence_group_sync( + tenant_id: str, cc_pair: ConnectorCredentialPair, ) -> list[ExternalUserGroup]: - confluence_client = build_confluence_client( - credentials=cc_pair.credential.credential_json, - is_cloud=cc_pair.connector.connector_specific_config.get("is_cloud", False), - wiki_base=cc_pair.connector.connector_specific_config["wiki_base"], - ) + provider = OnyxDBCredentialsProvider(tenant_id, "confluence", cc_pair.credential_id) + is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False) + wiki_base: str = cc_pair.connector.connector_specific_config["wiki_base"] + url = wiki_base.rstrip("/") + + probe_kwargs = { + "max_backoff_retries": 6, + "max_backoff_seconds": 10, + } + + final_kwargs = { + "max_backoff_retries": 10, + "max_backoff_seconds": 60, + } + + confluence_client = OnyxConfluence(is_cloud, url, provider) + confluence_client._probe_connection(**probe_kwargs) + confluence_client._initialize_connection(**final_kwargs) group_member_email_map = _build_group_member_email_map( confluence_client=confluence_client, diff --git a/backend/ee/onyx/external_permissions/gmail/doc_sync.py b/backend/ee/onyx/external_permissions/gmail/doc_sync.py index a5563d73b..6f1bae674 100644 --- a/backend/ee/onyx/external_permissions/gmail/doc_sync.py +++ b/backend/ee/onyx/external_permissions/gmail/doc_sync.py @@ -32,7 +32,8 @@ def _get_slim_doc_generator( def gmail_doc_sync( - cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None + cc_pair: ConnectorCredentialPair, + callback: IndexingHeartbeatInterface | None, ) 
-> list[DocExternalAccess]: """ Adds the external permissions to the documents in postgres diff --git a/backend/ee/onyx/external_permissions/google_drive/doc_sync.py b/backend/ee/onyx/external_permissions/google_drive/doc_sync.py index 32f8993d0..8d3df7fa8 100644 --- a/backend/ee/onyx/external_permissions/google_drive/doc_sync.py +++ b/backend/ee/onyx/external_permissions/google_drive/doc_sync.py @@ -145,7 +145,8 @@ def _get_permissions_from_slim_doc( def gdrive_doc_sync( - cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None + cc_pair: ConnectorCredentialPair, + callback: IndexingHeartbeatInterface | None, ) -> list[DocExternalAccess]: """ Adds the external permissions to the documents in postgres diff --git a/backend/ee/onyx/external_permissions/google_drive/group_sync.py b/backend/ee/onyx/external_permissions/google_drive/group_sync.py index 7d1a27dbe..241aa4780 100644 --- a/backend/ee/onyx/external_permissions/google_drive/group_sync.py +++ b/backend/ee/onyx/external_permissions/google_drive/group_sync.py @@ -119,6 +119,7 @@ def _build_onyx_groups( def gdrive_group_sync( + tenant_id: str, cc_pair: ConnectorCredentialPair, ) -> list[ExternalUserGroup]: # Initialize connector and build credential/service objects diff --git a/backend/ee/onyx/external_permissions/slack/doc_sync.py b/backend/ee/onyx/external_permissions/slack/doc_sync.py index 9522c906d..0ae9b58cc 100644 --- a/backend/ee/onyx/external_permissions/slack/doc_sync.py +++ b/backend/ee/onyx/external_permissions/slack/doc_sync.py @@ -123,7 +123,8 @@ def _fetch_channel_permissions( def slack_doc_sync( - cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None + cc_pair: ConnectorCredentialPair, + callback: IndexingHeartbeatInterface | None, ) -> list[DocExternalAccess]: """ Adds the external permissions to the documents in postgres diff --git a/backend/ee/onyx/external_permissions/sync_params.py b/backend/ee/onyx/external_permissions/sync_params.py index 
8be6dcb2c..9f8ed9681 100644 --- a/backend/ee/onyx/external_permissions/sync_params.py +++ b/backend/ee/onyx/external_permissions/sync_params.py @@ -28,6 +28,7 @@ DocSyncFuncType = Callable[ GroupSyncFuncType = Callable[ [ + str, ConnectorCredentialPair, ], list[ExternalUserGroup], diff --git a/backend/ee/onyx/main.py b/backend/ee/onyx/main.py index cf6b8191c..7d7278bb2 100644 --- a/backend/ee/onyx/main.py +++ b/backend/ee/onyx/main.py @@ -15,7 +15,7 @@ from ee.onyx.server.enterprise_settings.api import ( ) from ee.onyx.server.manage.standard_answer import router as standard_answer_router from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware -from ee.onyx.server.oauth import router as oauth_router +from ee.onyx.server.oauth.api import router as oauth_router from ee.onyx.server.query_and_chat.chat_backend import ( router as chat_router, ) @@ -152,4 +152,8 @@ def get_application() -> FastAPI: # environment variable. Used to automate deployment for multiple environments. 
seed_db() + # for debugging discovered routes + # for route in application.router.routes: + # print(f"Path: {route.path}, Methods: {route.methods}") + return application diff --git a/backend/ee/onyx/server/oauth.py b/backend/ee/onyx/server/oauth.py deleted file mode 100644 index 7204ee1a8..000000000 --- a/backend/ee/onyx/server/oauth.py +++ /dev/null @@ -1,629 +0,0 @@ -import base64 -import json -import uuid -from typing import Any -from typing import cast - -import requests -from fastapi import APIRouter -from fastapi import Depends -from fastapi import HTTPException -from fastapi.responses import JSONResponse -from pydantic import BaseModel -from sqlalchemy.orm import Session - -from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID -from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET -from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID -from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET -from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID -from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET -from onyx.auth.users import current_user -from onyx.configs.app_configs import WEB_DOMAIN -from onyx.configs.constants import DocumentSource -from onyx.connectors.google_utils.google_auth import get_google_oauth_creds -from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials -from onyx.connectors.google_utils.shared_constants import ( - DB_CREDENTIALS_AUTHENTICATION_METHOD, -) -from onyx.connectors.google_utils.shared_constants import ( - DB_CREDENTIALS_DICT_TOKEN_KEY, -) -from onyx.connectors.google_utils.shared_constants import ( - DB_CREDENTIALS_PRIMARY_ADMIN_KEY, -) -from onyx.connectors.google_utils.shared_constants import ( - GoogleOAuthAuthenticationMethod, -) -from onyx.db.credentials import create_credential -from onyx.db.engine import get_session -from onyx.db.models import User -from onyx.redis.redis_pool import get_redis_client -from 
onyx.server.documents.models import CredentialBase -from onyx.utils.logger import setup_logger -from shared_configs.contextvars import get_current_tenant_id - - -logger = setup_logger() - -router = APIRouter(prefix="/oauth") - - -class SlackOAuth: - # https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth - # Example: https://api.slack.com/authentication/oauth-v2#exchanging - - class OAuthSession(BaseModel): - """Stored in redis to be looked up on callback""" - - email: str - redirect_on_success: str | None # Where to send the user if OAuth flow succeeds - - CLIENT_ID = OAUTH_SLACK_CLIENT_ID - CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET - - TOKEN_URL = "https://slack.com/api/oauth.v2.access" - - # SCOPE is per https://docs.onyx.app/connectors/slack - BOT_SCOPE = ( - "channels:history," - "channels:read," - "groups:history," - "groups:read," - "channels:join," - "im:history," - "users:read," - "users:read.email," - "usergroups:read" - ) - - REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback" - DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}" - - @classmethod - def generate_oauth_url(cls, state: str) -> str: - return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state) - - @classmethod - def generate_dev_oauth_url(cls, state: str) -> str: - """dev mode workaround for localhost testing - - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https - """ - - return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state) - - @classmethod - def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str: - url = ( - f"https://slack.com/oauth/v2/authorize" - f"?client_id={cls.CLIENT_ID}" - f"&redirect_uri={redirect_uri}" - f"&scope={cls.BOT_SCOPE}" - f"&state={state}" - ) - return url - - @classmethod - def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str: - """Temporary state to store in redis. to be looked up on auth response. - Returns a json string. 
- """ - session = SlackOAuth.OAuthSession( - email=email, redirect_on_success=redirect_on_success - ) - return session.model_dump_json() - - @classmethod - def parse_session(cls, session_json: str) -> OAuthSession: - session = SlackOAuth.OAuthSession.model_validate_json(session_json) - return session - - -class ConfluenceCloudOAuth: - """work in progress""" - - # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/ - - class OAuthSession(BaseModel): - """Stored in redis to be looked up on callback""" - - email: str - redirect_on_success: str | None # Where to send the user if OAuth flow succeeds - - CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID - CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET - TOKEN_URL = "https://auth.atlassian.com/oauth/token" - - # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/ - CONFLUENCE_OAUTH_SCOPE = ( - "read:confluence-props%20" - "read:confluence-content.all%20" - "read:confluence-content.summary%20" - "read:confluence-content.permission%20" - "read:confluence-user%20" - "read:confluence-groups%20" - "readonly:content.attachment:confluence" - ) - - REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback" - DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}" - - # eventually for Confluence Data Center - # oauth_url = ( - # f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}" - # f"&scope={CONFLUENCE_OAUTH_SCOPE_2}" - # f"&redirect_uri={redirectme_uri}" - # ) - - @classmethod - def generate_oauth_url(cls, state: str) -> str: - return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state) - - @classmethod - def generate_dev_oauth_url(cls, state: str) -> str: - """dev mode workaround for localhost testing - - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https - """ - return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state) - - @classmethod - def _generate_oauth_url_helper(cls, 
redirect_uri: str, state: str) -> str: - url = ( - "https://auth.atlassian.com/authorize" - f"?audience=api.atlassian.com" - f"&client_id={cls.CLIENT_ID}" - f"&redirect_uri={redirect_uri}" - f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}" - f"&state={state}" - "&response_type=code" - "&prompt=consent" - ) - return url - - @classmethod - def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str: - """Temporary state to store in redis. to be looked up on auth response. - Returns a json string. - """ - session = ConfluenceCloudOAuth.OAuthSession( - email=email, redirect_on_success=redirect_on_success - ) - return session.model_dump_json() - - @classmethod - def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession: - session = SlackOAuth.OAuthSession.model_validate_json(session_json) - return session - - -class GoogleDriveOAuth: - # https://developers.google.com/identity/protocols/oauth2 - # https://developers.google.com/identity/protocols/oauth2/web-server - - class OAuthSession(BaseModel): - """Stored in redis to be looked up on callback""" - - email: str - redirect_on_success: str | None # Where to send the user if OAuth flow succeeds - - CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID - CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET - - TOKEN_URL = "https://oauth2.googleapis.com/token" - - # SCOPE is per https://docs.onyx.app/connectors/google-drive - # TODO: Merge with or use google_utils.GOOGLE_SCOPES - SCOPE = ( - "https://www.googleapis.com/auth/drive.readonly%20" - "https://www.googleapis.com/auth/drive.metadata.readonly%20" - "https://www.googleapis.com/auth/admin.directory.user.readonly%20" - "https://www.googleapis.com/auth/admin.directory.group.readonly" - ) - - REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback" - DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}" - - @classmethod - def generate_oauth_url(cls, state: str) -> str: - return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state) - - 
@classmethod - def generate_dev_oauth_url(cls, state: str) -> str: - """dev mode workaround for localhost testing - - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https - """ - - return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state) - - @classmethod - def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str: - # without prompt=consent, a refresh token is only issued the first time the user approves - url = ( - f"https://accounts.google.com/o/oauth2/v2/auth" - f"?client_id={cls.CLIENT_ID}" - f"&redirect_uri={redirect_uri}" - "&response_type=code" - f"&scope={cls.SCOPE}" - "&access_type=offline" - f"&state={state}" - "&prompt=consent" - ) - return url - - @classmethod - def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str: - """Temporary state to store in redis. to be looked up on auth response. - Returns a json string. - """ - session = GoogleDriveOAuth.OAuthSession( - email=email, redirect_on_success=redirect_on_success - ) - return session.model_dump_json() - - @classmethod - def parse_session(cls, session_json: str) -> OAuthSession: - session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json) - return session - - -@router.post("/prepare-authorization-request") -def prepare_authorization_request( - connector: DocumentSource, - redirect_on_success: str | None, - user: User = Depends(current_user), -) -> JSONResponse: - """Used by the frontend to generate the url for the user's browser during auth request. 
- - Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/ - """ - tenant_id = get_current_tenant_id() - - # create random oauth state param for security and to retrieve user data later - oauth_uuid = uuid.uuid4() - oauth_uuid_str = str(oauth_uuid) - - # urlsafe b64 encode the uuid for the oauth url - oauth_state = ( - base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8") - ) - session: str - - if connector == DocumentSource.SLACK: - oauth_url = SlackOAuth.generate_oauth_url(oauth_state) - session = SlackOAuth.session_dump_json( - email=user.email, redirect_on_success=redirect_on_success - ) - elif connector == DocumentSource.GOOGLE_DRIVE: - oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state) - session = GoogleDriveOAuth.session_dump_json( - email=user.email, redirect_on_success=redirect_on_success - ) - # elif connector == DocumentSource.CONFLUENCE: - # oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state) - # session = ConfluenceCloudOAuth.session_dump_json( - # email=user.email, redirect_on_success=redirect_on_success - # ) - # elif connector == DocumentSource.JIRA: - # oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state) - else: - oauth_url = None - - if not oauth_url: - raise HTTPException( - status_code=404, - detail=f"The document source type {connector} does not have OAuth implemented", - ) - - r = get_redis_client(tenant_id=tenant_id) - - # store important session state to retrieve when the user is redirected back - # 10 min is the max we want an oauth flow to be valid - r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600) - - return JSONResponse(content={"url": oauth_url}) - - -@router.post("/connector/slack/callback") -def handle_slack_oauth_callback( - code: str, - state: str, - user: User = Depends(current_user), - db_session: Session = Depends(get_session), -) -> JSONResponse: - if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET: - raise HTTPException( - 
status_code=500, - detail="Slack client ID or client secret is not configured.", - ) - - r = get_redis_client() - - # recover the state - padded_state = state + "=" * ( - -len(state) % 4 - ) # Add padding back (Base64 decoding requires padding) - uuid_bytes = base64.urlsafe_b64decode( - padded_state - ) # Decode the Base64 string back to bytes - - # Convert bytes back to a UUID - oauth_uuid = uuid.UUID(bytes=uuid_bytes) - oauth_uuid_str = str(oauth_uuid) - - r_key = f"da_oauth:{oauth_uuid_str}" - - session_json_bytes = cast(bytes, r.get(r_key)) - if not session_json_bytes: - raise HTTPException( - status_code=400, - detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}", - ) - - session_json = session_json_bytes.decode("utf-8") - try: - session = SlackOAuth.parse_session(session_json) - - # Exchange the authorization code for an access token - response = requests.post( - SlackOAuth.TOKEN_URL, - headers={"Content-Type": "application/x-www-form-urlencoded"}, - data={ - "client_id": SlackOAuth.CLIENT_ID, - "client_secret": SlackOAuth.CLIENT_SECRET, - "code": code, - "redirect_uri": SlackOAuth.REDIRECT_URI, - }, - ) - - response_data = response.json() - - if not response_data.get("ok"): - raise HTTPException( - status_code=400, - detail=f"Slack OAuth failed: {response_data.get('error')}", - ) - - # Extract token and team information - access_token: str = response_data.get("access_token") - team_id: str = response_data.get("team", {}).get("id") - authed_user_id: str = response_data.get("authed_user", {}).get("id") - - credential_info = CredentialBase( - credential_json={"slack_bot_token": access_token}, - admin_public=True, - source=DocumentSource.SLACK, - name="Slack OAuth", - ) - - create_credential(credential_info, user, db_session) - except Exception as e: - return JSONResponse( - status_code=500, - content={ - "success": False, - "message": f"An error occurred during Slack OAuth: {str(e)}", - }, - ) - finally: - r.delete(r_key) - - # return the 
result - return JSONResponse( - content={ - "success": True, - "message": "Slack OAuth completed successfully.", - "team_id": team_id, - "authed_user_id": authed_user_id, - "redirect_on_success": session.redirect_on_success, - } - ) - - -# Work in progress -# @router.post("/connector/confluence/callback") -# def handle_confluence_oauth_callback( -# code: str, -# state: str, -# user: User = Depends(current_user), -# db_session: Session = Depends(get_session), -# tenant_id: str | None = Depends(get_current_tenant_id), -# ) -> JSONResponse: -# if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET: -# raise HTTPException( -# status_code=500, -# detail="Confluence client ID or client secret is not configured." -# ) - -# r = get_redis_client(tenant_id=tenant_id) - -# # recover the state -# padded_state = state + '=' * (-len(state) % 4) # Add padding back (Base64 decoding requires padding) -# uuid_bytes = base64.urlsafe_b64decode(padded_state) # Decode the Base64 string back to bytes - -# # Convert bytes back to a UUID -# oauth_uuid = uuid.UUID(bytes=uuid_bytes) -# oauth_uuid_str = str(oauth_uuid) - -# r_key = f"da_oauth:{oauth_uuid_str}" - -# result = r.get(r_key) -# if not result: -# raise HTTPException( -# status_code=400, -# detail=f"Confluence OAuth failed - OAuth state key not found: key={r_key}" -# ) - -# try: -# session = ConfluenceCloudOAuth.parse_session(result) - -# # Exchange the authorization code for an access token -# response = requests.post( -# ConfluenceCloudOAuth.TOKEN_URL, -# headers={"Content-Type": "application/x-www-form-urlencoded"}, -# data={ -# "client_id": ConfluenceCloudOAuth.CLIENT_ID, -# "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET, -# "code": code, -# "redirect_uri": ConfluenceCloudOAuth.DEV_REDIRECT_URI, -# }, -# ) - -# response_data = response.json() - -# if not response_data.get("ok"): -# raise HTTPException( -# status_code=400, -# detail=f"ConfluenceCloudOAuth OAuth failed: {response_data.get('error')}" 
-# ) - -# # Extract token and team information -# access_token: str = response_data.get("access_token") -# team_id: str = response_data.get("team", {}).get("id") -# authed_user_id: str = response_data.get("authed_user", {}).get("id") - -# credential_info = CredentialBase( -# credential_json={"slack_bot_token": access_token}, -# admin_public=True, -# source=DocumentSource.CONFLUENCE, -# name="Confluence OAuth", -# ) - -# logger.info(f"Slack access token: {access_token}") - -# credential = create_credential(credential_info, user, db_session) - -# logger.info(f"new_credential_id={credential.id}") -# except Exception as e: -# return JSONResponse( -# status_code=500, -# content={ -# "success": False, -# "message": f"An error occurred during Slack OAuth: {str(e)}", -# }, -# ) -# finally: -# r.delete(r_key) - -# # return the result -# return JSONResponse( -# content={ -# "success": True, -# "message": "Slack OAuth completed successfully.", -# "team_id": team_id, -# "authed_user_id": authed_user_id, -# "redirect_on_success": session.redirect_on_success, -# } -# ) - - -@router.post("/connector/google-drive/callback") -def handle_google_drive_oauth_callback( - code: str, - state: str, - user: User = Depends(current_user), - db_session: Session = Depends(get_session), -) -> JSONResponse: - if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET: - raise HTTPException( - status_code=500, - detail="Google Drive client ID or client secret is not configured.", - ) - - r = get_redis_client() - - # recover the state - padded_state = state + "=" * ( - -len(state) % 4 - ) # Add padding back (Base64 decoding requires padding) - uuid_bytes = base64.urlsafe_b64decode( - padded_state - ) # Decode the Base64 string back to bytes - - # Convert bytes back to a UUID - oauth_uuid = uuid.UUID(bytes=uuid_bytes) - oauth_uuid_str = str(oauth_uuid) - - r_key = f"da_oauth:{oauth_uuid_str}" - - session_json_bytes = cast(bytes, r.get(r_key)) - if not session_json_bytes: - raise 
HTTPException( - status_code=400, - detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}", - ) - - session_json = session_json_bytes.decode("utf-8") - session: GoogleDriveOAuth.OAuthSession - try: - session = GoogleDriveOAuth.parse_session(session_json) - - # Exchange the authorization code for an access token - response = requests.post( - GoogleDriveOAuth.TOKEN_URL, - headers={"Content-Type": "application/x-www-form-urlencoded"}, - data={ - "client_id": GoogleDriveOAuth.CLIENT_ID, - "client_secret": GoogleDriveOAuth.CLIENT_SECRET, - "code": code, - "redirect_uri": GoogleDriveOAuth.REDIRECT_URI, - "grant_type": "authorization_code", - }, - ) - - response.raise_for_status() - - authorization_response: dict[str, Any] = response.json() - - # the connector wants us to store the json in its authorized_user_info format - # returned from OAuthCredentials.get_authorized_user_info(). - # So refresh immediately via get_google_oauth_creds with the params filled in - # from fields in authorization_response to get the json we need - authorized_user_info = {} - authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID - authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET - authorized_user_info["refresh_token"] = authorization_response["refresh_token"] - - token_json_str = json.dumps(authorized_user_info) - oauth_creds = get_google_oauth_creds( - token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE - ) - if not oauth_creds: - raise RuntimeError("get_google_oauth_creds returned None.") - - # save off the credentials - oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds) - - credential_dict: dict[str, str] = {} - credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str - credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email - credential_dict[ - DB_CREDENTIALS_AUTHENTICATION_METHOD - ] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value - - credential_info = 
CredentialBase( - credential_json=credential_dict, - admin_public=True, - source=DocumentSource.GOOGLE_DRIVE, - name="OAuth (interactive)", - ) - - create_credential(credential_info, user, db_session) - except Exception as e: - return JSONResponse( - status_code=500, - content={ - "success": False, - "message": f"An error occurred during Google Drive OAuth: {str(e)}", - }, - ) - finally: - r.delete(r_key) - - # return the result - return JSONResponse( - content={ - "success": True, - "message": "Google Drive OAuth completed successfully.", - "redirect_on_success": session.redirect_on_success, - } - ) diff --git a/backend/ee/onyx/server/oauth/api.py b/backend/ee/onyx/server/oauth/api.py new file mode 100644 index 000000000..f9eb4a751 --- /dev/null +++ b/backend/ee/onyx/server/oauth/api.py @@ -0,0 +1,91 @@ +import base64 +import uuid + +from fastapi import Depends +from fastapi import HTTPException +from fastapi.responses import JSONResponse + +from ee.onyx.server.oauth.api_router import router +from ee.onyx.server.oauth.confluence_cloud import ConfluenceCloudOAuth +from ee.onyx.server.oauth.google_drive import GoogleDriveOAuth +from ee.onyx.server.oauth.slack import SlackOAuth +from onyx.auth.users import current_admin_user +from onyx.configs.app_configs import DEV_MODE +from onyx.configs.constants import DocumentSource +from onyx.db.engine import get_current_tenant_id +from onyx.db.models import User +from onyx.redis.redis_pool import get_redis_client +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +@router.post("/prepare-authorization-request") +def prepare_authorization_request( + connector: DocumentSource, + redirect_on_success: str | None, + user: User = Depends(current_admin_user), + tenant_id: str | None = Depends(get_current_tenant_id), +) -> JSONResponse: + """Used by the frontend to generate the url for the user's browser during auth request. 
+ + Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/ + """ + + # create random oauth state param for security and to retrieve user data later + oauth_uuid = uuid.uuid4() + oauth_uuid_str = str(oauth_uuid) + + # urlsafe b64 encode the uuid for the oauth url + oauth_state = ( + base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8") + ) + + session: str | None = None + if connector == DocumentSource.SLACK: + if not DEV_MODE: + oauth_url = SlackOAuth.generate_oauth_url(oauth_state) + else: + oauth_url = SlackOAuth.generate_dev_oauth_url(oauth_state) + + session = SlackOAuth.session_dump_json( + email=user.email, redirect_on_success=redirect_on_success + ) + elif connector == DocumentSource.CONFLUENCE: + if not DEV_MODE: + oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state) + else: + oauth_url = ConfluenceCloudOAuth.generate_dev_oauth_url(oauth_state) + session = ConfluenceCloudOAuth.session_dump_json( + email=user.email, redirect_on_success=redirect_on_success + ) + elif connector == DocumentSource.GOOGLE_DRIVE: + if not DEV_MODE: + oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state) + else: + oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state) + session = GoogleDriveOAuth.session_dump_json( + email=user.email, redirect_on_success=redirect_on_success + ) + else: + oauth_url = None + + if not oauth_url: + raise HTTPException( + status_code=404, + detail=f"The document source type {connector} does not have OAuth implemented", + ) + + if not session: + raise HTTPException( + status_code=500, + detail=f"The document source type {connector} failed to generate an OAuth session.", + ) + + r = get_redis_client(tenant_id=tenant_id) + + # store important session state to retrieve when the user is redirected back + # 10 min is the max we want an oauth flow to be valid + r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600) + + return JSONResponse(content={"url": oauth_url}) diff --git 
a/backend/ee/onyx/server/oauth/api_router.py b/backend/ee/onyx/server/oauth/api_router.py new file mode 100644 index 000000000..b99ec55a4 --- /dev/null +++ b/backend/ee/onyx/server/oauth/api_router.py @@ -0,0 +1,3 @@ +from fastapi import APIRouter + +router: APIRouter = APIRouter(prefix="/oauth") diff --git a/backend/ee/onyx/server/oauth/confluence_cloud.py b/backend/ee/onyx/server/oauth/confluence_cloud.py new file mode 100644 index 000000000..22fd23f98 --- /dev/null +++ b/backend/ee/onyx/server/oauth/confluence_cloud.py @@ -0,0 +1,361 @@ +import base64 +import uuid +from datetime import datetime +from datetime import timedelta +from datetime import timezone +from typing import Any +from typing import cast + +import requests +from fastapi import Depends +from fastapi import HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel +from pydantic import ValidationError +from sqlalchemy.orm import Session + +from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID +from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET +from ee.onyx.server.oauth.api_router import router +from onyx.auth.users import current_admin_user +from onyx.configs.app_configs import DEV_MODE +from onyx.configs.app_configs import WEB_DOMAIN +from onyx.configs.constants import DocumentSource +from onyx.connectors.confluence.utils import CONFLUENCE_OAUTH_TOKEN_URL +from onyx.db.credentials import create_credential +from onyx.db.credentials import fetch_credential_by_id_for_user +from onyx.db.credentials import update_credential_json +from onyx.db.engine import get_current_tenant_id +from onyx.db.engine import get_session +from onyx.db.models import User +from onyx.redis.redis_pool import get_redis_client +from onyx.server.documents.models import CredentialBase +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +class ConfluenceCloudOAuth: + # 
https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/ + + class OAuthSession(BaseModel): + """Stored in redis to be looked up on callback""" + + email: str + redirect_on_success: str | None # Where to send the user if OAuth flow succeeds + + class TokenResponse(BaseModel): + access_token: str + expires_in: int + token_type: str + refresh_token: str + scope: str + + class AccessibleResources(BaseModel): + id: str + name: str + url: str + scopes: list[str] + avatarUrl: str + + CLIENT_ID = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID + CLIENT_SECRET = OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET + TOKEN_URL = CONFLUENCE_OAUTH_TOKEN_URL + + ACCESSIBLE_RESOURCE_URL = ( + "https://api.atlassian.com/oauth/token/accessible-resources" + ) + + # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/ + CONFLUENCE_OAUTH_SCOPE = ( + # classic scope + "read:confluence-space.summary%20" + "read:confluence-props%20" + "read:confluence-content.all%20" + "read:confluence-content.summary%20" + "read:confluence-content.permission%20" + "read:confluence-user%20" + "read:confluence-groups%20" + "readonly:content.attachment:confluence%20" + "search:confluence%20" + # granular scope + "read:attachment:confluence%20" # possibly unneeded unless calling v2 attachments api + "offline_access" + ) + + REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback" + DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}" + + # eventually for Confluence Data Center + # oauth_url = ( + # f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}" + # f"&scope={CONFLUENCE_OAUTH_SCOPE_2}" + # f"&redirect_uri={redirectme_uri}" + # ) + + @classmethod + def generate_oauth_url(cls, state: str) -> str: + return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state) + + @classmethod + def generate_dev_oauth_url(cls, state: str) -> str: + """dev mode workaround for localhost testing + - 
https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https + """ + return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state) + + @classmethod + def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str: + # https://developer.atlassian.com/cloud/jira/platform/oauth-2-3lo-apps/#1--direct-the-user-to-the-authorization-url-to-get-an-authorization-code + + url = ( + "https://auth.atlassian.com/authorize" + f"?audience=api.atlassian.com" + f"&client_id={cls.CLIENT_ID}" + f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}" + f"&redirect_uri={redirect_uri}" + f"&state={state}" + "&response_type=code" + "&prompt=consent" + ) + return url + + @classmethod + def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str: + """Temporary state to store in redis. to be looked up on auth response. + Returns a json string. + """ + session = ConfluenceCloudOAuth.OAuthSession( + email=email, redirect_on_success=redirect_on_success + ) + return session.model_dump_json() + + @classmethod + def parse_session(cls, session_json: str) -> OAuthSession: + session = ConfluenceCloudOAuth.OAuthSession.model_validate_json(session_json) + return session + + @classmethod + def generate_finalize_url(cls, credential_id: int) -> str: + return f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/finalize?credential={credential_id}" + + +@router.post("/connector/confluence/callback") +def confluence_oauth_callback( + code: str, + state: str, + user: User = Depends(current_admin_user), + db_session: Session = Depends(get_session), + tenant_id: str | None = Depends(get_current_tenant_id), +) -> JSONResponse: + """Handles the backend logic for the frontend page that the user is redirected to + after visiting the oauth authorization url.""" + + if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET: + raise HTTPException( + status_code=500, + detail="Confluence Cloud client ID or client secret is not configured.", + ) + + r = 
get_redis_client(tenant_id=tenant_id) + + # recover the state + padded_state = state + "=" * ( + -len(state) % 4 + ) # Add padding back (Base64 decoding requires padding) + uuid_bytes = base64.urlsafe_b64decode( + padded_state + ) # Decode the Base64 string back to bytes + + # Convert bytes back to a UUID + oauth_uuid = uuid.UUID(bytes=uuid_bytes) + oauth_uuid_str = str(oauth_uuid) + + r_key = f"da_oauth:{oauth_uuid_str}" + + session_json_bytes = cast(bytes, r.get(r_key)) + if not session_json_bytes: + raise HTTPException( + status_code=400, + detail=f"Confluence Cloud OAuth failed - OAuth state key not found: key={r_key}", + ) + + session_json = session_json_bytes.decode("utf-8") + try: + session = ConfluenceCloudOAuth.parse_session(session_json) + + if not DEV_MODE: + redirect_uri = ConfluenceCloudOAuth.REDIRECT_URI + else: + redirect_uri = ConfluenceCloudOAuth.DEV_REDIRECT_URI + + # Exchange the authorization code for an access token + response = requests.post( + ConfluenceCloudOAuth.TOKEN_URL, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "client_id": ConfluenceCloudOAuth.CLIENT_ID, + "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET, + "code": code, + "redirect_uri": redirect_uri, + "grant_type": "authorization_code", + }, + ) + + token_response: ConfluenceCloudOAuth.TokenResponse | None = None + + try: + token_response = ConfluenceCloudOAuth.TokenResponse.model_validate_json( + response.text + ) + except Exception: + raise RuntimeError( + "Confluence Cloud OAuth failed during code/token exchange." 
+ ) + + now = datetime.now(timezone.utc) + expires_at = now + timedelta(seconds=token_response.expires_in) + + credential_info = CredentialBase( + credential_json={ + "confluence_access_token": token_response.access_token, + "confluence_refresh_token": token_response.refresh_token, + "created_at": now.isoformat(), + "expires_at": expires_at.isoformat(), + "expires_in": token_response.expires_in, + "scope": token_response.scope, + }, + admin_public=True, + source=DocumentSource.CONFLUENCE, + name="Confluence Cloud OAuth", + ) + + credential = create_credential(credential_info, user, db_session) + except Exception as e: + return JSONResponse( + status_code=500, + content={ + "success": False, + "message": f"An error occurred during Confluence Cloud OAuth: {str(e)}", + }, + ) + finally: + r.delete(r_key) + + # return the result + return JSONResponse( + content={ + "success": True, + "message": "Confluence Cloud OAuth completed successfully.", + "finalize_url": ConfluenceCloudOAuth.generate_finalize_url(credential.id), + "redirect_on_success": session.redirect_on_success, + } + ) + + +@router.get("/connector/confluence/accessible-resources") +def confluence_oauth_accessible_resources( + credential_id: int, + user: User = Depends(current_admin_user), + db_session: Session = Depends(get_session), + tenant_id: str | None = Depends(get_current_tenant_id), +) -> JSONResponse: + """Atlassian's API is weird and does not supply us with enough info to be in a + usable state after authorizing. All API's require a cloud id. 
We have to list + the accessible resources/sites and let the user choose which site to use.""" + + credential = fetch_credential_by_id_for_user(credential_id, user, db_session) + if not credential: + raise HTTPException(400, f"Credential {credential_id} not found.") + + credential_dict = credential.credential_json + access_token = credential_dict["confluence_access_token"] + + try: + # Exchange the authorization code for an access token + response = requests.get( + ConfluenceCloudOAuth.ACCESSIBLE_RESOURCE_URL, + headers={ + "Authorization": f"Bearer {access_token}", + "Accept": "application/json", + }, + ) + + response.raise_for_status() + accessible_resources_data = response.json() + + # Validate the list of AccessibleResources + try: + accessible_resources = [ + ConfluenceCloudOAuth.AccessibleResources(**resource) + for resource in accessible_resources_data + ] + except ValidationError as e: + raise RuntimeError(f"Failed to parse accessible resources: {e}") + except Exception as e: + return JSONResponse( + status_code=500, + content={ + "success": False, + "message": f"An error occurred retrieving Confluence Cloud accessible resources: {str(e)}", + }, + ) + + # return the result + return JSONResponse( + content={ + "success": True, + "message": "Confluence Cloud get accessible resources completed successfully.", + "accessible_resources": [ + resource.model_dump() for resource in accessible_resources + ], + } + ) + + +@router.post("/connector/confluence/finalize") +def confluence_oauth_finalize( + credential_id: int, + cloud_id: str, + cloud_name: str, + cloud_url: str, + user: User = Depends(current_admin_user), + db_session: Session = Depends(get_session), + tenant_id: str | None = Depends(get_current_tenant_id), +) -> JSONResponse: + """Saves the info for the selected cloud site to the credential. + This is the final step in the confluence oauth flow where after the traditional + OAuth process, the user has to select a site to associate with the credentials. 
+ After this, the credential is usable.""" + + credential = fetch_credential_by_id_for_user(credential_id, user, db_session) + if not credential: + raise HTTPException( + status_code=400, + detail=f"Confluence Cloud OAuth failed - credential {credential_id} not found.", + ) + + new_credential_json: dict[str, Any] = dict(credential.credential_json) + new_credential_json["cloud_id"] = cloud_id + new_credential_json["cloud_name"] = cloud_name + new_credential_json["wiki_base"] = cloud_url + + try: + update_credential_json(credential_id, new_credential_json, user, db_session) + except Exception as e: + return JSONResponse( + status_code=500, + content={ + "success": False, + "message": f"An error occurred during Confluence Cloud OAuth: {str(e)}", + }, + ) + + # return the result + return JSONResponse( + content={ + "success": True, + "message": "Confluence Cloud OAuth finalized successfully.", + "redirect_url": f"{WEB_DOMAIN}/admin/connectors/confluence", + } + ) diff --git a/backend/ee/onyx/server/oauth/google_drive.py b/backend/ee/onyx/server/oauth/google_drive.py new file mode 100644 index 000000000..68f224c76 --- /dev/null +++ b/backend/ee/onyx/server/oauth/google_drive.py @@ -0,0 +1,229 @@ +import base64 +import json +import uuid +from typing import Any +from typing import cast + +import requests +from fastapi import Depends +from fastapi import HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID +from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET +from ee.onyx.server.oauth.api_router import router +from onyx.auth.users import current_admin_user +from onyx.configs.app_configs import DEV_MODE +from onyx.configs.app_configs import WEB_DOMAIN +from onyx.configs.constants import DocumentSource +from onyx.connectors.google_utils.google_auth import get_google_oauth_creds +from 
onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials +from onyx.connectors.google_utils.shared_constants import ( + DB_CREDENTIALS_AUTHENTICATION_METHOD, +) +from onyx.connectors.google_utils.shared_constants import ( + DB_CREDENTIALS_DICT_TOKEN_KEY, +) +from onyx.connectors.google_utils.shared_constants import ( + DB_CREDENTIALS_PRIMARY_ADMIN_KEY, +) +from onyx.connectors.google_utils.shared_constants import ( + GoogleOAuthAuthenticationMethod, +) +from onyx.db.credentials import create_credential +from onyx.db.engine import get_current_tenant_id +from onyx.db.engine import get_session +from onyx.db.models import User +from onyx.redis.redis_pool import get_redis_client +from onyx.server.documents.models import CredentialBase + + +class GoogleDriveOAuth: + # https://developers.google.com/identity/protocols/oauth2 + # https://developers.google.com/identity/protocols/oauth2/web-server + + class OAuthSession(BaseModel): + """Stored in redis to be looked up on callback""" + + email: str + redirect_on_success: str | None # Where to send the user if OAuth flow succeeds + + CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID + CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET + + TOKEN_URL = "https://oauth2.googleapis.com/token" + + # SCOPE is per https://docs.danswer.dev/connectors/google-drive + # TODO: Merge with or use google_utils.GOOGLE_SCOPES + SCOPE = ( + "https://www.googleapis.com/auth/drive.readonly%20" + "https://www.googleapis.com/auth/drive.metadata.readonly%20" + "https://www.googleapis.com/auth/admin.directory.user.readonly%20" + "https://www.googleapis.com/auth/admin.directory.group.readonly" + ) + + REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback" + DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}" + + @classmethod + def generate_oauth_url(cls, state: str) -> str: + return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state) + + @classmethod + def generate_dev_oauth_url(cls, state: str) -> str: + """dev 
mode workaround for localhost testing + - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https + """ + + return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state) + + @classmethod + def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str: + # without prompt=consent, a refresh token is only issued the first time the user approves + url = ( + f"https://accounts.google.com/o/oauth2/v2/auth" + f"?client_id={cls.CLIENT_ID}" + f"&redirect_uri={redirect_uri}" + "&response_type=code" + f"&scope={cls.SCOPE}" + "&access_type=offline" + f"&state={state}" + "&prompt=consent" + ) + return url + + @classmethod + def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str: + """Temporary state to store in redis. to be looked up on auth response. + Returns a json string. + """ + session = GoogleDriveOAuth.OAuthSession( + email=email, redirect_on_success=redirect_on_success + ) + return session.model_dump_json() + + @classmethod + def parse_session(cls, session_json: str) -> OAuthSession: + session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json) + return session + + +@router.post("/connector/google-drive/callback") +def handle_google_drive_oauth_callback( + code: str, + state: str, + user: User = Depends(current_admin_user), + db_session: Session = Depends(get_session), + tenant_id: str | None = Depends(get_current_tenant_id), +) -> JSONResponse: + if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET: + raise HTTPException( + status_code=500, + detail="Google Drive client ID or client secret is not configured.", + ) + + r = get_redis_client(tenant_id=tenant_id) + + # recover the state + padded_state = state + "=" * ( + -len(state) % 4 + ) # Add padding back (Base64 decoding requires padding) + uuid_bytes = base64.urlsafe_b64decode( + padded_state + ) # Decode the Base64 string back to bytes + + # Convert bytes back to a UUID + oauth_uuid = uuid.UUID(bytes=uuid_bytes) + oauth_uuid_str = 
str(oauth_uuid) + + r_key = f"da_oauth:{oauth_uuid_str}" + + session_json_bytes = cast(bytes, r.get(r_key)) + if not session_json_bytes: + raise HTTPException( + status_code=400, + detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}", + ) + + session_json = session_json_bytes.decode("utf-8") + try: + session = GoogleDriveOAuth.parse_session(session_json) + + if not DEV_MODE: + redirect_uri = GoogleDriveOAuth.REDIRECT_URI + else: + redirect_uri = GoogleDriveOAuth.DEV_REDIRECT_URI + + # Exchange the authorization code for an access token + response = requests.post( + GoogleDriveOAuth.TOKEN_URL, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "client_id": GoogleDriveOAuth.CLIENT_ID, + "client_secret": GoogleDriveOAuth.CLIENT_SECRET, + "code": code, + "redirect_uri": redirect_uri, + "grant_type": "authorization_code", + }, + ) + + response.raise_for_status() + + authorization_response: dict[str, Any] = response.json() + + # the connector wants us to store the json in its authorized_user_info format + # returned from OAuthCredentials.get_authorized_user_info(). 
+ # So refresh immediately via get_google_oauth_creds with the params filled in + # from fields in authorization_response to get the json we need + authorized_user_info = {} + authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID + authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET + authorized_user_info["refresh_token"] = authorization_response["refresh_token"] + + token_json_str = json.dumps(authorized_user_info) + oauth_creds = get_google_oauth_creds( + token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE + ) + if not oauth_creds: + raise RuntimeError("get_google_oauth_creds returned None.") + + # save off the credentials + oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds) + + credential_dict: dict[str, str] = {} + credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str + credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email + credential_dict[ + DB_CREDENTIALS_AUTHENTICATION_METHOD + ] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value + + credential_info = CredentialBase( + credential_json=credential_dict, + admin_public=True, + source=DocumentSource.GOOGLE_DRIVE, + name="OAuth (interactive)", + ) + + create_credential(credential_info, user, db_session) + except Exception as e: + return JSONResponse( + status_code=500, + content={ + "success": False, + "message": f"An error occurred during Google Drive OAuth: {str(e)}", + }, + ) + finally: + r.delete(r_key) + + # return the result + return JSONResponse( + content={ + "success": True, + "message": "Google Drive OAuth completed successfully.", + "finalize_url": None, + "redirect_on_success": session.redirect_on_success, + } + ) diff --git a/backend/ee/onyx/server/oauth/slack.py b/backend/ee/onyx/server/oauth/slack.py new file mode 100644 index 000000000..e8c5c3063 --- /dev/null +++ b/backend/ee/onyx/server/oauth/slack.py @@ -0,0 +1,197 @@ +import base64 +import uuid +from typing import cast + +import 
requests +from fastapi import Depends +from fastapi import HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID +from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET +from ee.onyx.server.oauth.api_router import router +from onyx.auth.users import current_admin_user +from onyx.configs.app_configs import DEV_MODE +from onyx.configs.app_configs import WEB_DOMAIN +from onyx.configs.constants import DocumentSource +from onyx.db.credentials import create_credential +from onyx.db.engine import get_current_tenant_id +from onyx.db.engine import get_session +from onyx.db.models import User +from onyx.redis.redis_pool import get_redis_client +from onyx.server.documents.models import CredentialBase + + +class SlackOAuth: + # https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth + # Example: https://api.slack.com/authentication/oauth-v2#exchanging + + class OAuthSession(BaseModel): + """Stored in redis to be looked up on callback""" + + email: str + redirect_on_success: str | None # Where to send the user if OAuth flow succeeds + + CLIENT_ID = OAUTH_SLACK_CLIENT_ID + CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET + + TOKEN_URL = "https://slack.com/api/oauth.v2.access" + + # SCOPE is per https://docs.danswer.dev/connectors/slack + BOT_SCOPE = ( + "channels:history," + "channels:read," + "groups:history," + "groups:read," + "channels:join," + "im:history," + "users:read," + "users:read.email," + "usergroups:read" + ) + + REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback" + DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}" + + @classmethod + def generate_oauth_url(cls, state: str) -> str: + return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state) + + @classmethod + def generate_dev_oauth_url(cls, state: str) -> str: + """dev mode workaround for localhost testing + - 
https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https + """ + + return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state) + + @classmethod + def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str: + url = ( + f"https://slack.com/oauth/v2/authorize" + f"?client_id={cls.CLIENT_ID}" + f"&redirect_uri={redirect_uri}" + f"&scope={cls.BOT_SCOPE}" + f"&state={state}" + ) + return url + + @classmethod + def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str: + """Temporary state to store in redis. to be looked up on auth response. + Returns a json string. + """ + session = SlackOAuth.OAuthSession( + email=email, redirect_on_success=redirect_on_success + ) + return session.model_dump_json() + + @classmethod + def parse_session(cls, session_json: str) -> OAuthSession: + session = SlackOAuth.OAuthSession.model_validate_json(session_json) + return session + + +@router.post("/connector/slack/callback") +def handle_slack_oauth_callback( + code: str, + state: str, + user: User = Depends(current_admin_user), + db_session: Session = Depends(get_session), + tenant_id: str | None = Depends(get_current_tenant_id), +) -> JSONResponse: + if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET: + raise HTTPException( + status_code=500, + detail="Slack client ID or client secret is not configured.", + ) + + r = get_redis_client(tenant_id=tenant_id) + + # recover the state + padded_state = state + "=" * ( + -len(state) % 4 + ) # Add padding back (Base64 decoding requires padding) + uuid_bytes = base64.urlsafe_b64decode( + padded_state + ) # Decode the Base64 string back to bytes + + # Convert bytes back to a UUID + oauth_uuid = uuid.UUID(bytes=uuid_bytes) + oauth_uuid_str = str(oauth_uuid) + + r_key = f"da_oauth:{oauth_uuid_str}" + + session_json_bytes = cast(bytes, r.get(r_key)) + if not session_json_bytes: + raise HTTPException( + status_code=400, + detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}", 
+ ) + + session_json = session_json_bytes.decode("utf-8") + try: + session = SlackOAuth.parse_session(session_json) + + if not DEV_MODE: + redirect_uri = SlackOAuth.REDIRECT_URI + else: + redirect_uri = SlackOAuth.DEV_REDIRECT_URI + + # Exchange the authorization code for an access token + response = requests.post( + SlackOAuth.TOKEN_URL, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "client_id": SlackOAuth.CLIENT_ID, + "client_secret": SlackOAuth.CLIENT_SECRET, + "code": code, + "redirect_uri": redirect_uri, + }, + ) + + response_data = response.json() + + if not response_data.get("ok"): + raise HTTPException( + status_code=400, + detail=f"Slack OAuth failed: {response_data.get('error')}", + ) + + # Extract token and team information + access_token: str = response_data.get("access_token") + team_id: str = response_data.get("team", {}).get("id") + authed_user_id: str = response_data.get("authed_user", {}).get("id") + + credential_info = CredentialBase( + credential_json={"slack_bot_token": access_token}, + admin_public=True, + source=DocumentSource.SLACK, + name="Slack OAuth", + ) + + create_credential(credential_info, user, db_session) + except Exception as e: + return JSONResponse( + status_code=500, + content={ + "success": False, + "message": f"An error occurred during Slack OAuth: {str(e)}", + }, + ) + finally: + r.delete(r_key) + + # return the result + return JSONResponse( + content={ + "success": True, + "message": "Slack OAuth completed successfully.", + "finalize_url": None, + "redirect_on_success": session.redirect_on_success, + "team_id": team_id, + "authed_user_id": authed_user_id, + } + ) diff --git a/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py b/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py index 6aa257305..1599e0ae1 100644 --- a/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py +++ b/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py @@ 
-423,7 +423,7 @@ def connector_external_group_sync_generator_task( ) external_user_groups: list[ExternalUserGroup] = [] try: - external_user_groups = ext_group_sync_func(cc_pair) + external_user_groups = ext_group_sync_func(tenant_id, cc_pair) except ConnectorValidationError as e: msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}" update_connector_credential_pair( diff --git a/backend/onyx/background/indexing/run_indexing.py b/backend/onyx/background/indexing/run_indexing.py index d569d6dfd..fd63c0d9b 100644 --- a/backend/onyx/background/indexing/run_indexing.py +++ b/backend/onyx/background/indexing/run_indexing.py @@ -93,10 +93,11 @@ def _get_connector_runner( runnable_connector.validate_connector_settings() except Exception as e: - logger.exception(f"Unable to instantiate connector due to {e}") - + logger.exception("Unable to instantiate connector.") # since we failed to even instantiate the connector, we pause the CCPair since - # it will never succeed. Sometimes there are cases where the connector will + # it will never succeed + + # Sometimes there are cases where the connector will # intermittently fail to initialize in which case we should pass in # leave_connector_active=True to allow it to continue. 
# For example, if there is nightly maintenance on a Confluence Server instance, diff --git a/backend/onyx/connectors/confluence/connector.py b/backend/onyx/connectors/confluence/connector.py index 43d00c42b..29e279014 100644 --- a/backend/onyx/connectors/confluence/connector.py +++ b/backend/onyx/connectors/confluence/connector.py @@ -11,17 +11,20 @@ from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource -from onyx.connectors.confluence.onyx_confluence import build_confluence_client +from onyx.connectors.confluence.onyx_confluence import attachment_to_content +from onyx.connectors.confluence.onyx_confluence import ( + extract_text_from_confluence_html, +) from onyx.connectors.confluence.onyx_confluence import OnyxConfluence -from onyx.connectors.confluence.utils import attachment_to_content from onyx.connectors.confluence.utils import build_confluence_document_id from onyx.connectors.confluence.utils import datetime_from_string -from onyx.connectors.confluence.utils import extract_text_from_confluence_html from onyx.connectors.confluence.utils import validate_attachment_filetype from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialExpiredError from onyx.connectors.exceptions import InsufficientPermissionsError from onyx.connectors.exceptions import UnexpectedError +from onyx.connectors.interfaces import CredentialsConnector +from onyx.connectors.interfaces import CredentialsProviderInterface from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.connectors.interfaces import LoadConnector @@ -83,7 +86,9 @@ _FULL_EXTENSION_FILTER_STRING = "".join( ) -class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector): +class ConfluenceConnector( 
+ LoadConnector, PollConnector, SlimConnector, CredentialsConnector +): def __init__( self, wiki_base: str, @@ -102,7 +107,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector): ) -> None: self.batch_size = batch_size self.continue_on_failure = continue_on_failure - self._confluence_client: OnyxConfluence | None = None self.is_cloud = is_cloud # Remove trailing slash from wiki_base if present @@ -137,6 +141,19 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector): self.cql_label_filter = f" and label not in ({comma_separated_labels})" self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset)) + self.credentials_provider: CredentialsProviderInterface | None = None + + self.probe_kwargs = { + "max_backoff_retries": 6, + "max_backoff_seconds": 10, + } + + self.final_kwargs = { + "max_backoff_retries": 10, + "max_backoff_seconds": 60, + } + + self._confluence_client: OnyxConfluence | None = None @property def confluence_client(self) -> OnyxConfluence: @@ -144,15 +161,22 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector): raise ConnectorMissingCredentialError("Confluence") return self._confluence_client - def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: - # see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py - # for a list of other hidden constructor args - self._confluence_client = build_confluence_client( - credentials=credentials, - is_cloud=self.is_cloud, - wiki_base=self.wiki_base, + def set_credentials_provider( + self, credentials_provider: CredentialsProviderInterface + ) -> None: + self.credentials_provider = credentials_provider + + # raises exception if there's a problem + confluence_client = OnyxConfluence( + self.is_cloud, self.wiki_base, credentials_provider ) - return None + confluence_client._probe_connection(**self.probe_kwargs) + confluence_client._initialize_connection(**self.final_kwargs) + + 
self._confluence_client = confluence_client + + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + raise NotImplementedError("Use set_credentials_provider with this connector.") def _construct_page_query( self, @@ -202,12 +226,17 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector): return comment_string def _convert_object_to_document( - self, confluence_object: dict[str, Any] + self, + confluence_object: dict[str, Any], + parent_content_id: str | None = None, ) -> Document | None: """ Takes in a confluence object, extracts all metadata, and converts it into a document. If its a page, it extracts the text, adds the comments for the document text. If its an attachment, it just downloads the attachment and converts that into a document. + + parent_content_id: if the object is an attachment, specifies the content id that + the attachment is attached to """ # The url and the id are the same object_url = build_confluence_document_id( @@ -226,7 +255,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector): object_text += self._get_comment_string_for_page_id(confluence_object["id"]) elif confluence_object["type"] == "attachment": object_text = attachment_to_content( - confluence_client=self.confluence_client, attachment=confluence_object + confluence_client=self.confluence_client, + attachment=confluence_object, + parent_content_id=parent_content_id, ) if object_text is None: @@ -302,7 +333,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector): cql=attachment_query, expand=",".join(_ATTACHMENT_EXPANSION_FIELDS), ): - doc = self._convert_object_to_document(attachment) + doc = self._convert_object_to_document(attachment, confluence_page_id) if doc is not None: doc_batch.append(doc) if len(doc_batch) >= self.batch_size: diff --git a/backend/onyx/connectors/confluence/onyx_confluence.py b/backend/onyx/connectors/confluence/onyx_confluence.py index df28900bc..147ed82c6 100644 --- 
a/backend/onyx/connectors/confluence/onyx_confluence.py +++ b/backend/onyx/connectors/confluence/onyx_confluence.py @@ -1,19 +1,37 @@ -import math +import io +import json import time from collections.abc import Callable from collections.abc import Iterator +from datetime import datetime +from datetime import timedelta +from datetime import timezone from typing import Any from typing import cast from typing import TypeVar from urllib.parse import quote +import bs4 from atlassian import Confluence # type:ignore from pydantic import BaseModel +from redis import Redis from requests import HTTPError +from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID +from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET +from onyx.configs.app_configs import ( + CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD, +) +from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD +from onyx.connectors.confluence.utils import _handle_http_error +from onyx.connectors.confluence.utils import confluence_refresh_tokens from onyx.connectors.confluence.utils import get_start_param_from_url from onyx.connectors.confluence.utils import update_param_in_path -from onyx.connectors.exceptions import ConnectorValidationError +from onyx.connectors.confluence.utils import validate_attachment_filetype +from onyx.connectors.interfaces import CredentialsProviderInterface +from onyx.file_processing.extract_file_text import extract_file_text +from onyx.file_processing.html_utils import format_document_soup +from onyx.redis.redis_pool import get_redis_client from onyx.utils.logger import setup_logger logger = setup_logger() @@ -22,12 +40,14 @@ logger = setup_logger() F = TypeVar("F", bound=Callable[..., Any]) -RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower() - # https://jira.atlassian.com/browse/CONFCLOUD-76433 _PROBLEMATIC_EXPANSIONS = "body.storage.value" _REPLACEMENT_EXPANSIONS = "body.view.value" +_USER_NOT_FOUND = "Unknown 
Confluence User" +_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {} +_USER_EMAIL_CACHE: dict[str, str | None] = {} + class ConfluenceRateLimitError(Exception): pass @@ -43,124 +63,349 @@ class ConfluenceUser(BaseModel): type: str -def _handle_http_error(e: HTTPError, attempt: int) -> int: - MIN_DELAY = 2 - MAX_DELAY = 60 - STARTING_DELAY = 5 - BACKOFF = 2 - - # Check if the response or headers are None to avoid potential AttributeError - if e.response is None or e.response.headers is None: - logger.warning("HTTPError with `None` as response or as headers") - raise e - - if ( - e.response.status_code != 429 - and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower() - ): - raise e - - retry_after = None - - retry_after_header = e.response.headers.get("Retry-After") - if retry_after_header is not None: - try: - retry_after = int(retry_after_header) - if retry_after > MAX_DELAY: - logger.warning( - f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..." - ) - retry_after = MAX_DELAY - if retry_after < MIN_DELAY: - retry_after = MIN_DELAY - except ValueError: - pass - - if retry_after is not None: - logger.warning( - f"Rate limiting with retry header. Retrying after {retry_after} seconds..." - ) - delay = retry_after - else: - logger.warning( - "Rate limiting without retry header. Retrying with exponential backoff..." 
- ) - delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY) - - delay_until = math.ceil(time.monotonic() + delay) - return delay_until - - -# https://developer.atlassian.com/cloud/confluence/rate-limiting/ -# this uses the native rate limiting option provided by the -# confluence client and otherwise applies a simpler set of error handling -def handle_confluence_rate_limit(confluence_call: F) -> F: - def wrapped_call(*args: list[Any], **kwargs: Any) -> Any: - MAX_RETRIES = 5 - - TIMEOUT = 600 - timeout_at = time.monotonic() + TIMEOUT - - for attempt in range(MAX_RETRIES): - if time.monotonic() > timeout_at: - raise TimeoutError( - f"Confluence call attempts took longer than {TIMEOUT} seconds." - ) - - try: - # we're relying more on the client to rate limit itself - # and applying our own retries in a more specific set of circumstances - return confluence_call(*args, **kwargs) - except HTTPError as e: - delay_until = _handle_http_error(e, attempt) - logger.warning( - f"HTTPError in confluence call. " - f"Retrying in {delay_until} seconds..." - ) - while time.monotonic() < delay_until: - # in the future, check a signal here to exit - time.sleep(1) - except AttributeError as e: - # Some error within the Confluence library, unclear why it fails. - # Users reported it to be intermittent, so just retry - if attempt == MAX_RETRIES - 1: - raise e - - logger.exception( - "Confluence Client raised an AttributeError. Retrying..." - ) - time.sleep(5) - - return cast(F, wrapped_call) - - _DEFAULT_PAGINATION_LIMIT = 1000 _MINIMUM_PAGINATION_LIMIT = 50 -class OnyxConfluence(Confluence): +class OnyxConfluence: """ - This is a custom Confluence class that overrides the default Confluence class to add a custom CQL method. + This is a custom Confluence class that: + + A. overrides the default Confluence class to add a custom CQL method. + B. This is necessary because the default Confluence class does not properly support cql expansions. 
All methods are automatically wrapped with handle_confluence_rate_limit. """ - def __init__(self, url: str, *args: Any, **kwargs: Any) -> None: - super(OnyxConfluence, self).__init__(url, *args, **kwargs) - self._wrap_methods() + CREDENTIAL_PREFIX = "connector:confluence:credential" + CREDENTIAL_TTL = 300 # 5 min - def _wrap_methods(self) -> None: + def __init__( + self, + is_cloud: bool, + url: str, + credentials_provider: CredentialsProviderInterface, + ) -> None: + self._is_cloud = is_cloud + self._url = url.rstrip("/") + self._credentials_provider = credentials_provider + + self.redis_client: Redis | None = None + self.static_credentials: dict[str, Any] | None = None + if self._credentials_provider.is_dynamic(): + self.redis_client = get_redis_client( + tenant_id=credentials_provider.get_tenant_id() + ) + else: + self.static_credentials = self._credentials_provider.get_credentials() + + self._confluence = Confluence(url) + self.credential_key: str = ( + self.CREDENTIAL_PREFIX + + f":credential_{self._credentials_provider.get_provider_key()}" + ) + + self._kwargs: Any = None + + self.shared_base_kwargs = { + "api_version": "cloud" if is_cloud else "latest", + "backoff_and_retry": True, + "cloud": is_cloud, + } + + def _renew_credentials(self) -> tuple[dict[str, Any], bool]: + """credential_json - the current json credentials + Returns a tuple + 1. The up to date credentials + 2. True if the credentials were updated + + This method is intended to be used within a distributed lock. + Lock, call this, update credentials if the tokens were refreshed, then release """ - For each attribute that is callable (i.e., a method) and doesn't start with an underscore, - wrap it with handle_confluence_rate_limit. 
- """ - for attr_name in dir(self): - if callable(getattr(self, attr_name)) and not attr_name.startswith("_"): - setattr( - self, - attr_name, - handle_confluence_rate_limit(getattr(self, attr_name)), + # static credentials are preloaded, so no locking/redis required + if self.static_credentials: + return self.static_credentials, False + + if not self.redis_client: + raise RuntimeError("self.redis_client is None") + + # dynamic credentials need locking + # check redis first, then fallback to the DB + credential_raw = self.redis_client.get(self.credential_key) + if credential_raw is not None: + credential_bytes = cast(bytes, credential_raw) + credential_str = credential_bytes.decode("utf-8") + credential_json: dict[str, Any] = json.loads(credential_str) + else: + credential_json = self._credentials_provider.get_credentials() + + if "confluence_refresh_token" not in credential_json: + # static credentials ... cache them permanently and return + self.static_credentials = credential_json + return credential_json, False + + # check if we should refresh tokens. we're deciding to refresh halfway + # to expiration + now = datetime.now(timezone.utc) + created_at = datetime.fromisoformat(credential_json["created_at"]) + expires_in: int = credential_json["expires_in"] + renew_at = created_at + timedelta(seconds=expires_in // 2) + if now <= renew_at: + # cached/current credentials are reasonably up to date + return credential_json, False + + # we need to refresh + logger.info("Renewing Confluence Cloud credentials...") + new_credentials = confluence_refresh_tokens( + OAUTH_CONFLUENCE_CLOUD_CLIENT_ID, + OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET, + credential_json["cloud_id"], + credential_json["confluence_refresh_token"], + ) + + # store the new credentials to redis and to the db thru the provider + # redis: we use a 5 min TTL because we are given a 10 minute grace period + # when keys are rotated. 
it's easier to expire the cached credentials + # reasonably frequently rather than trying to handle strong synchronization + # between the db and redis everywhere the credentials might be updated + new_credential_str = json.dumps(new_credentials) + self.redis_client.set( + self.credential_key, new_credential_str, nx=True, ex=self.CREDENTIAL_TTL + ) + self._credentials_provider.set_credentials(new_credentials) + + return new_credentials, True + + @staticmethod + def _make_oauth2_dict(credentials: dict[str, Any]) -> dict[str, Any]: + oauth2_dict: dict[str, Any] = {} + if "confluence_refresh_token" in credentials: + oauth2_dict["client_id"] = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID + oauth2_dict["token"] = {} + oauth2_dict["token"]["access_token"] = credentials[ + "confluence_access_token" + ] + return oauth2_dict + + def _probe_connection( + self, + **kwargs: Any, + ) -> None: + merged_kwargs = {**self.shared_base_kwargs, **kwargs} + + with self._credentials_provider: + credentials, _ = self._renew_credentials() + + # probe connection with direct client, no retries + if "confluence_refresh_token" in credentials: + logger.info("Probing Confluence with OAuth Access Token.") + + oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict( + credentials ) + url = ( + f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}" + ) + confluence_client_with_minimal_retries = Confluence( + url=url, oauth2=oauth2_dict, **merged_kwargs + ) + else: + logger.info("Probing Confluence with Personal Access Token.") + url = self._url + if self._is_cloud: + confluence_client_with_minimal_retries = Confluence( + url=url, + username=credentials["confluence_username"], + password=credentials["confluence_access_token"], + **merged_kwargs, + ) + else: + confluence_client_with_minimal_retries = Confluence( + url=url, + token=credentials["confluence_access_token"], + **merged_kwargs, + ) + + spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1) + + # uncomment the 
following for testing + # the following is an attempt to retrieve the user's timezone + # Unfornately, all data is returned in UTC regardless of the user's time zone + # even tho CQL parses incoming times based on the user's time zone + # space_key = spaces["results"][0]["key"] + # space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space") + + if not spaces: + raise RuntimeError( + f"No spaces found at {url}! " + "Check your credentials and wiki_base and make sure " + "is_cloud is set correctly." + ) + + logger.info("Confluence probe succeeded.") + + def _initialize_connection( + self, + **kwargs: Any, + ) -> None: + """Called externally to init the connection in a thread safe manner.""" + merged_kwargs = {**self.shared_base_kwargs, **kwargs} + with self._credentials_provider: + credentials, _ = self._renew_credentials() + self._confluence = self._initialize_connection_helper( + credentials, **merged_kwargs + ) + self._kwargs = merged_kwargs + + def _initialize_connection_helper( + self, + credentials: dict[str, Any], + **kwargs: Any, + ) -> Confluence: + """Called internally to init the connection. 
Distributed locking + to prevent multiple threads from modifying the credentials + must be handled around this function.""" + + confluence = None + + # probe connection with direct client, no retries + if "confluence_refresh_token" in credentials: + logger.info("Connecting to Confluence Cloud with OAuth Access Token.") + + oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(credentials) + url = f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}" + confluence = Confluence(url=url, oauth2=oauth2_dict, **kwargs) + else: + logger.info("Connecting to Confluence with Personal Access Token.") + if self._is_cloud: + confluence = Confluence( + url=self._url, + username=credentials["confluence_username"], + password=credentials["confluence_access_token"], + **kwargs, + ) + else: + confluence = Confluence( + url=self._url, + token=credentials["confluence_access_token"], + **kwargs, + ) + + return confluence + + # https://developer.atlassian.com/cloud/confluence/rate-limiting/ + # this uses the native rate limiting option provided by the + # confluence client and otherwise applies a simpler set of error handling + def _make_rate_limited_confluence_method( + self, name: str, credential_provider: CredentialsProviderInterface | None + ) -> Callable[..., Any]: + def wrapped_call(*args: list[Any], **kwargs: Any) -> Any: + MAX_RETRIES = 5 + + TIMEOUT = 600 + timeout_at = time.monotonic() + TIMEOUT + + for attempt in range(MAX_RETRIES): + if time.monotonic() > timeout_at: + raise TimeoutError( + f"Confluence call attempts took longer than {TIMEOUT} seconds." 
+ ) + + # we're relying more on the client to rate limit itself + # and applying our own retries in a more specific set of circumstances + try: + if credential_provider: + with credential_provider: + credentials, renewed = self._renew_credentials() + if renewed: + self._confluence = self._initialize_connection_helper( + credentials, **self._kwargs + ) + attr = getattr(self._confluence, name, None) + if attr is None: + # The underlying Confluence client doesn't have this attribute + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + return attr(*args, **kwargs) + else: + attr = getattr(self._confluence, name, None) + if attr is None: + # The underlying Confluence client doesn't have this attribute + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + return attr(*args, **kwargs) + + except HTTPError as e: + delay_until = _handle_http_error(e, attempt) + logger.warning( + f"HTTPError in confluence call. " + f"Retrying in {delay_until} seconds..." + ) + while time.monotonic() < delay_until: + # in the future, check a signal here to exit + time.sleep(1) + except AttributeError as e: + # Some error within the Confluence library, unclear why it fails. + # Users reported it to be intermittent, so just retry + if attempt == MAX_RETRIES - 1: + raise e + + logger.exception( + "Confluence Client raised an AttributeError. Retrying..." + ) + time.sleep(5) + + return wrapped_call + + # def _wrap_methods(self) -> None: + # """ + # For each attribute that is callable (i.e., a method) and doesn't start with an underscore, + # wrap it with handle_confluence_rate_limit. 
+ # """ + # for attr_name in dir(self): + # if callable(getattr(self, attr_name)) and not attr_name.startswith("_"): + # setattr( + # self, + # attr_name, + # handle_confluence_rate_limit(getattr(self, attr_name)), + # ) + + # def _ensure_token_valid(self) -> None: + # if self._token_is_expired(): + # self._refresh_token() + # # Re-init the Confluence client with the originally stored args + # self._confluence = Confluence(self._url, *self._args, **self._kwargs) + + def __getattr__(self, name: str) -> Any: + """Dynamically intercept attribute/method access.""" + attr = getattr(self._confluence, name, None) + if attr is None: + # The underlying Confluence client doesn't have this attribute + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + # If it's not a method, just return it after ensuring token validity + if not callable(attr): + return attr + + # skip methods that start with "_" + if name.startswith("_"): + return attr + + # wrap the method with our retry handler + rate_limited_method: Callable[ + ..., Any + ] = self._make_rate_limited_confluence_method(name, self._credentials_provider) + + def wrapped_method(*args: Any, **kwargs: Any) -> Any: + return rate_limited_method(*args, **kwargs) + + return wrapped_method def _paginate_url( self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False @@ -507,63 +752,212 @@ class OnyxConfluence(Confluence): return response -def _validate_connector_configuration( - credentials: dict[str, Any], - is_cloud: bool, - wiki_base: str, -) -> None: - # test connection with direct client, no retries - confluence_client_with_minimal_retries = Confluence( - api_version="cloud" if is_cloud else "latest", - url=wiki_base.rstrip("/"), - username=credentials["confluence_username"] if is_cloud else None, - password=credentials["confluence_access_token"] if is_cloud else None, - token=credentials["confluence_access_token"] if not is_cloud else None, - backoff_and_retry=True, - 
max_backoff_retries=6, - max_backoff_seconds=10, +def get_user_email_from_username__server( + confluence_client: OnyxConfluence, user_name: str +) -> str | None: + global _USER_EMAIL_CACHE + if _USER_EMAIL_CACHE.get(user_name) is None: + try: + response = confluence_client.get_mobile_parameters(user_name) + email = response.get("email") + except Exception: + logger.warning(f"failed to get confluence email for {user_name}") + # For now, we'll just return None and log a warning. This means + # we will keep retrying to get the email every group sync. + email = None + # We may want to just return a string that indicates failure so we dont + # keep retrying + # email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}" + _USER_EMAIL_CACHE[user_name] = email + return _USER_EMAIL_CACHE[user_name] + + +def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str: + """Get Confluence Display Name based on the account-id or userkey value + + Args: + user_id (str): The user id (i.e: the account-id or userkey) + confluence_client (Confluence): The Confluence Client + + Returns: + str: The User Display Name. 
'Unknown User' if the user is deactivated or not found + """ + global _USER_ID_TO_DISPLAY_NAME_CACHE + if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None: + try: + result = confluence_client.get_user_details_by_userkey(user_id) + found_display_name = result.get("displayName") + except Exception: + found_display_name = None + + if not found_display_name: + try: + result = confluence_client.get_user_details_by_accountid(user_id) + found_display_name = result.get("displayName") + except Exception: + found_display_name = None + + _USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name + + return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND + + +def attachment_to_content( + confluence_client: OnyxConfluence, + attachment: dict[str, Any], + parent_content_id: str | None = None, +) -> str | None: + """If it returns None, assume that we should skip this attachment.""" + if not validate_attachment_filetype(attachment): + return None + + if "api.atlassian.com" in confluence_client.url: + # https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get + if not parent_content_id: + logger.warning( + "parent_content_id is required to download attachments from Confluence Cloud!" + ) + return None + + download_link = ( + confluence_client.url + + f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download" + ) + else: + download_link = confluence_client.url + attachment["_links"]["download"] + + attachment_size = attachment["extensions"]["fileSize"] + if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD: + logger.warning( + f"Skipping {download_link} due to size. " + f"size={attachment_size} " + f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}" + ) + return None + + logger.info(f"_attachment_to_content - _session.get: link={download_link}") + + # why are we using session.get here? 
we probably won't retry these ... is that ok? + response = confluence_client._session.get(download_link) + if response.status_code != 200: + logger.warning( + f"Failed to fetch {download_link} with invalid status code {response.status_code}" + ) + return None + + extracted_text = extract_file_text( + io.BytesIO(response.content), + file_name=attachment["title"], + break_on_unprocessable=False, ) - spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1) + if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD: + logger.warning( + f"Skipping {download_link} due to char count. " + f"char count={len(extracted_text)} " + f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}" + ) + return None - # uncomment the following for testing - # the following is an attempt to retrieve the user's timezone - # Unfornately, all data is returned in UTC regardless of the user's time zone - # even tho CQL parses incoming times based on the user's time zone - # space_key = spaces["results"][0]["key"] - # space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space") + return extracted_text - if not spaces: - raise RuntimeError( - f"No spaces found at {wiki_base}! " - "Check your credentials and wiki_base and make sure " - "is_cloud is set correctly." 
+ +def extract_text_from_confluence_html( + confluence_client: OnyxConfluence, + confluence_object: dict[str, Any], + fetched_titles: set[str], +) -> str: + """Parse a Confluence html page and replace the 'user Id' by the real + User Display Name + + Args: + confluence_object (dict): The confluence object as a dict + confluence_client (Confluence): Confluence client + fetched_titles (set[str]): The titles of the pages that have already been fetched + Returns: + str: loaded and formated Confluence page + """ + body = confluence_object["body"] + object_html = body.get("storage", body.get("view", {})).get("value") + + soup = bs4.BeautifulSoup(object_html, "html.parser") + for user in soup.findAll("ri:user"): + user_id = ( + user.attrs["ri:account-id"] + if "ri:account-id" in user.attrs + else user.get("ri:userkey") + ) + if not user_id: + logger.warning( + "ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}" + ) + continue + # Include @ sign for tagging, more clear for LLM + user.replaceWith("@" + _get_user(confluence_client, user_id)) + + for html_page_reference in soup.findAll("ac:structured-macro"): + # Here, we only want to process page within page macros + if html_page_reference.attrs.get("ac:name") != "include": + continue + + page_data = html_page_reference.find("ri:page") + if not page_data: + logger.warning( + f"Skipping retrieval of {html_page_reference} because because page data is missing" + ) + continue + + page_title = page_data.attrs.get("ri:content-title") + if not page_title: + # only fetch pages that have a title + logger.warning( + f"Skipping retrieval of {html_page_reference} because it has no title" + ) + continue + + if page_title in fetched_titles: + # prevent recursive fetching of pages + logger.debug(f"Skipping {page_title} because it has already been fetched") + continue + + fetched_titles.add(page_title) + + # Wrap this in a try-except because there are some pages that might not exist + try: + page_query = f"type=page and 
title='{quote(page_title)}'" + + page_contents: dict[str, Any] | None = None + # Confluence enforces title uniqueness, so we should only get one result here + for page in confluence_client.paginated_cql_retrieval( + cql=page_query, + expand="body.storage.value", + limit=1, + ): + page_contents = page + break + except Exception as e: + logger.warning( + f"Error getting page contents for object {confluence_object}: {e}" + ) + continue + + if not page_contents: + continue + + text_from_page = extract_text_from_confluence_html( + confluence_client=confluence_client, + confluence_object=page_contents, + fetched_titles=fetched_titles, ) + html_page_reference.replaceWith(text_from_page) -def build_confluence_client( - credentials: dict[str, Any], - is_cloud: bool, - wiki_base: str, -) -> OnyxConfluence: - try: - _validate_connector_configuration( - credentials=credentials, - is_cloud=is_cloud, - wiki_base=wiki_base, - ) - except Exception as e: - raise ConnectorValidationError(str(e)) + for html_link_body in soup.findAll("ac:link-body"): + # This extracts the text from inline links in the page so they can be + # represented in the document text as plain text + try: + text_from_link = html_link_body.text + html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})") + except Exception as e: + logger.warning(f"Error processing ac:link-body: {e}") - return OnyxConfluence( - api_version="cloud" if is_cloud else "latest", - # Remove trailing slash from wiki_base if present - url=wiki_base.rstrip("/"), - # passing in username causes issues for Confluence data center - username=credentials["confluence_username"] if is_cloud else None, - password=credentials["confluence_access_token"] if is_cloud else None, - token=credentials["confluence_access_token"] if not is_cloud else None, - backoff_and_retry=True, - max_backoff_retries=10, - max_backoff_seconds=60, - cloud=is_cloud, - ) + return format_document_soup(soup) diff --git a/backend/onyx/connectors/confluence/utils.py 
b/backend/onyx/connectors/confluence/utils.py index b77696645..801e24d4a 100644 --- a/backend/onyx/connectors/confluence/utils.py +++ b/backend/onyx/connectors/confluence/utils.py @@ -1,185 +1,38 @@ -import io +import math +import time +from collections.abc import Callable from datetime import datetime +from datetime import timedelta from datetime import timezone from typing import Any +from typing import cast from typing import TYPE_CHECKING +from typing import TypeVar from urllib.parse import parse_qs from urllib.parse import quote from urllib.parse import urlparse import bs4 +import requests +from pydantic import BaseModel -from onyx.configs.app_configs import ( - CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD, -) -from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD -from onyx.file_processing.extract_file_text import extract_file_text -from onyx.file_processing.html_utils import format_document_soup from onyx.utils.logger import setup_logger if TYPE_CHECKING: - from onyx.connectors.confluence.onyx_confluence import OnyxConfluence + pass logger = setup_logger() - -_USER_EMAIL_CACHE: dict[str, str | None] = {} +CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token" +RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower() -def get_user_email_from_username__server( - confluence_client: "OnyxConfluence", user_name: str -) -> str | None: - global _USER_EMAIL_CACHE - if _USER_EMAIL_CACHE.get(user_name) is None: - try: - response = confluence_client.get_mobile_parameters(user_name) - email = response.get("email") - except Exception: - logger.warning(f"failed to get confluence email for {user_name}") - # For now, we'll just return None and log a warning. This means - # we will keep retrying to get the email every group sync. 
- email = None - # We may want to just return a string that indicates failure so we dont - # keep retrying - # email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}" - _USER_EMAIL_CACHE[user_name] = email - return _USER_EMAIL_CACHE[user_name] - - -_USER_NOT_FOUND = "Unknown Confluence User" -_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {} - - -def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str: - """Get Confluence Display Name based on the account-id or userkey value - - Args: - user_id (str): The user id (i.e: the account-id or userkey) - confluence_client (Confluence): The Confluence Client - - Returns: - str: The User Display Name. 'Unknown User' if the user is deactivated or not found - """ - global _USER_ID_TO_DISPLAY_NAME_CACHE - if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None: - try: - result = confluence_client.get_user_details_by_userkey(user_id) - found_display_name = result.get("displayName") - except Exception: - found_display_name = None - - if not found_display_name: - try: - result = confluence_client.get_user_details_by_accountid(user_id) - found_display_name = result.get("displayName") - except Exception: - found_display_name = None - - _USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name - - return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND - - -def extract_text_from_confluence_html( - confluence_client: "OnyxConfluence", - confluence_object: dict[str, Any], - fetched_titles: set[str], -) -> str: - """Parse a Confluence html page and replace the 'user Id' by the real - User Display Name - - Args: - confluence_object (dict): The confluence object as a dict - confluence_client (Confluence): Confluence client - fetched_titles (set[str]): The titles of the pages that have already been fetched - Returns: - str: loaded and formated Confluence page - """ - body = confluence_object["body"] - object_html = body.get("storage", body.get("view", {})).get("value") - - soup = 
bs4.BeautifulSoup(object_html, "html.parser") - for user in soup.findAll("ri:user"): - user_id = ( - user.attrs["ri:account-id"] - if "ri:account-id" in user.attrs - else user.get("ri:userkey") - ) - if not user_id: - logger.warning( - "ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}" - ) - continue - # Include @ sign for tagging, more clear for LLM - user.replaceWith("@" + _get_user(confluence_client, user_id)) - - for html_page_reference in soup.findAll("ac:structured-macro"): - # Here, we only want to process page within page macros - if html_page_reference.attrs.get("ac:name") != "include": - continue - - page_data = html_page_reference.find("ri:page") - if not page_data: - logger.warning( - f"Skipping retrieval of {html_page_reference} because because page data is missing" - ) - continue - - page_title = page_data.attrs.get("ri:content-title") - if not page_title: - # only fetch pages that have a title - logger.warning( - f"Skipping retrieval of {html_page_reference} because it has no title" - ) - continue - - if page_title in fetched_titles: - # prevent recursive fetching of pages - logger.debug(f"Skipping {page_title} because it has already been fetched") - continue - - fetched_titles.add(page_title) - - # Wrap this in a try-except because there are some pages that might not exist - try: - page_query = f"type=page and title='{quote(page_title)}'" - - page_contents: dict[str, Any] | None = None - # Confluence enforces title uniqueness, so we should only get one result here - for page in confluence_client.paginated_cql_retrieval( - cql=page_query, - expand="body.storage.value", - limit=1, - ): - page_contents = page - break - except Exception as e: - logger.warning( - f"Error getting page contents for object {confluence_object}: {e}" - ) - continue - - if not page_contents: - continue - - text_from_page = extract_text_from_confluence_html( - confluence_client=confluence_client, - confluence_object=page_contents, - 
fetched_titles=fetched_titles, - ) - - html_page_reference.replaceWith(text_from_page) - - for html_link_body in soup.findAll("ac:link-body"): - # This extracts the text from inline links in the page so they can be - # represented in the document text as plain text - try: - text_from_link = html_link_body.text - html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})") - except Exception as e: - logger.warning(f"Error processing ac:link-body: {e}") - - return format_document_soup(soup) +class TokenResponse(BaseModel): + access_token: str + expires_in: int + token_type: str + refresh_token: str + scope: str def validate_attachment_filetype(attachment: dict[str, Any]) -> bool: @@ -193,49 +46,6 @@ def validate_attachment_filetype(attachment: dict[str, Any]) -> bool: ] -def attachment_to_content( - confluence_client: "OnyxConfluence", - attachment: dict[str, Any], -) -> str | None: - """If it returns None, assume that we should skip this attachment.""" - if not validate_attachment_filetype(attachment): - return None - - download_link = confluence_client.url + attachment["_links"]["download"] - - attachment_size = attachment["extensions"]["fileSize"] - if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD: - logger.warning( - f"Skipping {download_link} due to size. " - f"size={attachment_size} " - f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}" - ) - return None - - logger.info(f"_attachment_to_content - _session.get: link={download_link}") - response = confluence_client._session.get(download_link) - if response.status_code != 200: - logger.warning( - f"Failed to fetch {download_link} with invalid status code {response.status_code}" - ) - return None - - extracted_text = extract_file_text( - io.BytesIO(response.content), - file_name=attachment["title"], - break_on_unprocessable=False, - ) - if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD: - logger.warning( - f"Skipping {download_link} due to char count. 
def confluence_refresh_tokens(
    client_id: str,
    client_secret: str,
    cloud_id: str,
    refresh_token: str,
    timeout: float = 60.0,
) -> dict[str, Any]:
    """Rotate a Confluence Cloud OAuth refresh/access token pair.

    Posts a ``refresh_token`` grant to ``CONFLUENCE_OAUTH_TOKEN_URL`` and
    converts the response into the credential dict stored for the connector.

    Args:
        client_id: OAuth client id sent in the form body.
        client_secret: OAuth client secret sent in the form body.
        cloud_id: Copied unchanged into the returned dict under ``cloud_id``.
        refresh_token: The refresh token to exchange.
        timeout: Seconds to wait for the token endpoint before ``requests``
            gives up. Without a timeout a stalled endpoint would block the
            caller indefinitely.

    Returns:
        A dict with the keys ``confluence_access_token``,
        ``confluence_refresh_token``, ``created_at``, ``expires_at``
        (both ISO-8601, UTC), ``expires_in``, ``scope`` and ``cloud_id``.

    Raises:
        RuntimeError: If the response body cannot be parsed as a
            ``TokenResponse`` (the parsing error is chained as the cause).
    """
    # rotate the refresh and access token
    # Note that access tokens are only good for an hour in confluence cloud,
    # so we're going to have problems if the connector runs for longer
    # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/#use-a-refresh-token-to-get-another-access-token-and-refresh-token-pair
    response = requests.post(
        CONFLUENCE_OAUTH_TOKEN_URL,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        data={
            "grant_type": "refresh_token",
            "client_id": client_id,
            "client_secret": client_secret,
            "refresh_token": refresh_token,
        },
        timeout=timeout,
    )

    try:
        token_response = TokenResponse.model_validate_json(response.text)
    except Exception as e:
        # Report only the status code: the body is deliberately not echoed
        # because it may carry token material.
        raise RuntimeError(
            "Confluence Cloud token refresh failed. "
            f"HTTP status: {response.status_code}"
        ) from e

    now = datetime.now(timezone.utc)
    expires_at = now + timedelta(seconds=token_response.expires_in)

    return {
        "confluence_access_token": token_response.access_token,
        "confluence_refresh_token": token_response.refresh_token,
        "created_at": now.isoformat(),
        "expires_at": expires_at.isoformat(),
        "expires_in": token_response.expires_in,
        "scope": token_response.scope,
        "cloud_id": cloud_id,
    }
F = TypeVar("F", bound=Callable[..., Any])


# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def handle_confluence_rate_limit(confluence_call: F) -> F:
    """Wrap a Confluence client call with bounded retry handling.

    The wrapped call is attempted up to ``MAX_RETRIES`` times within an
    overall budget of ``TIMEOUT`` seconds (checked before each attempt).

    * ``requests.HTTPError`` is passed to ``_handle_http_error``, which
      either re-raises it or returns the ``time.monotonic()`` deadline to
      wait for before the next attempt.
    * ``AttributeError`` is retried after a fixed 5 second pause.

    When the final attempt fails, the triggering exception is re-raised
    instead of the wrapper falling out of the loop and returning ``None``.

    Args:
        confluence_call: The callable to protect.

    Returns:
        A callable with the same signature (and ``__name__``/``__doc__``)
        as ``confluence_call``.

    Raises:
        TimeoutError: If the retry budget is exhausted before an attempt.
        requests.HTTPError: Re-raised by ``_handle_http_error`` or after the
            last rate-limited attempt.
        AttributeError: After the last attempt.
    """
    # Function-scope import so this block does not depend on a file-level
    # import of functools.
    from functools import wraps

    @wraps(confluence_call)
    def wrapped_call(*args: Any, **kwargs: Any) -> Any:
        MAX_RETRIES = 5

        TIMEOUT = 600
        timeout_at = time.monotonic() + TIMEOUT

        for attempt in range(MAX_RETRIES):
            if time.monotonic() > timeout_at:
                raise TimeoutError(
                    f"Confluence call attempts took longer than {TIMEOUT} seconds."
                )

            is_last_attempt = attempt == MAX_RETRIES - 1

            try:
                # we're relying more on the client to rate limit itself
                # and applying our own retries in a more specific set of circumstances
                return confluence_call(*args, **kwargs)
            except requests.HTTPError as e:
                # Re-raises anything it does not want retried.
                delay_until = _handle_http_error(e, attempt)
                if is_last_attempt:
                    # No attempt left to wait for; surface the error rather
                    # than sleeping and then returning None.
                    raise

                # delay_until is a monotonic deadline, not a duration.
                delay_seconds = max(0.0, delay_until - time.monotonic())
                logger.warning(
                    f"HTTPError in confluence call. "
                    f"Retrying in {delay_seconds:.0f} seconds..."
                )
                while time.monotonic() < delay_until:
                    # in the future, check a signal here to exit
                    time.sleep(1)
            except AttributeError:
                # Some error within the Confluence library, unclear why it fails.
                # Users reported it to be intermittent, so just retry
                if is_last_attempt:
                    raise

                logger.exception(
                    "Confluence Client raised an AttributeError. Retrying..."
                )
                time.sleep(5)

        # Every iteration returns, raises, or continues, and the final one
        # always raises; this guards against a silent None if that changes.
        raise RuntimeError(
            f"Confluence call failed after {MAX_RETRIES} attempts."
        )

    return cast(F, wrapped_call)
class OnyxDBCredentialsProvider(
    CredentialsProviderInterface["OnyxDBCredentialsProvider"]
):
    """Implementation to allow the connector to callback and update credentials in the db.
    Required in cases where credentials can rotate while the connector is running.

    Used as a context manager, it holds a Redis lock keyed on the connector
    name and credential id for the duration of the ``with`` block.
    """

    LOCK_TTL = 900  # TTL of the lock

    def __init__(self, tenant_id: str, connector_name: str, credential_id: int):
        """Bind the provider to one credential row of one tenant.

        Args:
            tenant_id: Tenant whose Redis client and DB session are used.
            connector_name: Used only to build the lock key.
            credential_id: Primary key of the ``Credential`` row to read/write.
        """
        self._tenant_id = tenant_id
        self._connector_name = connector_name
        self._credential_id = credential_id

        self.redis_client = get_redis_client(tenant_id=tenant_id)

        # lock used to prevent overlapping renewal of credentials
        self.lock_key = f"da_lock:connector:{connector_name}:credential_{credential_id}"
        self._lock: RedisLock = self.redis_client.lock(self.lock_key, self.LOCK_TTL)

    def __enter__(self) -> "OnyxDBCredentialsProvider":
        """Acquire the lock, waiting up to ``LOCK_TTL`` seconds.

        Raises:
            RuntimeError: If the lock could not be acquired in time.
        """
        acquired = self._lock.acquire(blocking_timeout=self.LOCK_TTL)
        if not acquired:
            raise RuntimeError(f"Could not acquire lock for key: {self.lock_key}")

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Release the lock when exiting the context."""
        # Only release a lock we still own; it may have expired meanwhile.
        if self._lock and self._lock.owned():
            self._lock.release()

    def get_tenant_id(self) -> str | None:
        """Return the tenant id this provider was created with."""
        return self._tenant_id

    def get_provider_key(self) -> str:
        """Return the credential id as a string."""
        return str(self._credential_id)

    def get_credentials(self) -> dict[str, Any]:
        """Load the current ``credential_json`` from the database.

        Raises:
            ValueError: If no ``Credential`` row has this provider's id.
        """
        with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
            # scalar_one_or_none() lets a missing row reach the ValueError
            # below; scalar_one() would raise before the None check.
            credential = db_session.execute(
                select(Credential).where(Credential.id == self._credential_id)
            ).scalar_one_or_none()

            if credential is None:
                raise ValueError(
                    f"No credential found: credential={self._credential_id}"
                )

            return credential.credential_json

    def set_credentials(self, credential_json: dict[str, Any]) -> None:
        """Overwrite ``credential_json`` on the row and commit.

        The row is selected ``FOR UPDATE``; any failure rolls the session
        back and re-raises.

        Raises:
            ValueError: If no ``Credential`` row has this provider's id.
        """
        with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
            try:
                credential = db_session.execute(
                    select(Credential)
                    .where(Credential.id == self._credential_id)
                    .with_for_update()
                ).scalar_one_or_none()

                if credential is None:
                    raise ValueError(
                        f"No credential found: credential={self._credential_id}"
                    )

                credential.credential_json = credential_json
                db_session.commit()
            except Exception:
                db_session.rollback()
                raise

    def is_dynamic(self) -> bool:
        """Always ``True``: these credentials may change while in use."""
        return True
import GuruConnector from onyx.connectors.hubspot.connector import HubSpotConnector from onyx.connectors.interfaces import BaseConnector from onyx.connectors.interfaces import CheckpointConnector +from onyx.connectors.interfaces import CredentialsConnector from onyx.connectors.interfaces import EventConnector from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector @@ -57,6 +59,7 @@ from onyx.db.connector import fetch_connector_by_id from onyx.db.credentials import backend_update_credential_json from onyx.db.credentials import fetch_credential_by_id from onyx.db.models import Credential +from shared_configs.contextvars import get_current_tenant_id class ConnectorMissingException(Exception): @@ -167,10 +170,17 @@ def instantiate_connector( connector_class = identify_connector_class(source, input_type) connector = connector_class(**connector_specific_config) - new_credentials = connector.load_credentials(credential.credential_json) - if new_credentials is not None: - backend_update_credential_json(credential, new_credentials, db_session) + if isinstance(connector, CredentialsConnector): + provider = OnyxDBCredentialsProvider( + get_current_tenant_id(), str(source), credential.id + ) + connector.set_credentials_provider(provider) + else: + new_credentials = connector.load_credentials(credential.credential_json) + + if new_credentials is not None: + backend_update_credential_json(credential, new_credentials, db_session) return connector diff --git a/backend/onyx/connectors/interfaces.py b/backend/onyx/connectors/interfaces.py index 8516d08a3..0b2f8b661 100644 --- a/backend/onyx/connectors/interfaces.py +++ b/backend/onyx/connectors/interfaces.py @@ -1,7 +1,10 @@ import abc from collections.abc import Generator from collections.abc import Iterator +from types import TracebackType from typing import Any +from typing import Generic +from typing import TypeVar from pydantic import BaseModel @@ -111,6 +114,69 @@ class 
OAuthConnector(BaseConnector): raise NotImplementedError +T = TypeVar("T", bound="CredentialsProviderInterface") + + +class CredentialsProviderInterface(abc.ABC, Generic[T]): + @abc.abstractmethod + def __enter__(self) -> T: + raise NotImplementedError + + @abc.abstractmethod + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + raise NotImplementedError + + @abc.abstractmethod + def get_tenant_id(self) -> str | None: + raise NotImplementedError + + @abc.abstractmethod + def get_provider_key(self) -> str: + """a unique key that the connector can use to lock around a credential + that might be used simultaneously. + + Will typically be the credential id, but can also just be something random + in cases when there is nothing to lock (aka static credentials) + """ + raise NotImplementedError + + @abc.abstractmethod + def get_credentials(self) -> dict[str, Any]: + raise NotImplementedError + + @abc.abstractmethod + def set_credentials(self, credential_json: dict[str, Any]) -> None: + raise NotImplementedError + + @abc.abstractmethod + def is_dynamic(self) -> bool: + """If dynamic, the credentials may change during usage ... meaning the client + needs to use the locking features of the credentials provider to operate + correctly. + + If static, the client can simply reference the credentials once and use them + through the entire indexing run. + """ + raise NotImplementedError + + +class CredentialsConnector(BaseConnector): + """Implement this if the connector needs to be able to read and write credentials + on the fly.
Typically used with shared credentials/tokens that might be renewed + at any time.""" + + @abc.abstractmethod + def set_credentials_provider( + self, credentials_provider: CredentialsProviderInterface + ) -> None: + raise NotImplementedError + + # Event driven class EventConnector(BaseConnector): @abc.abstractmethod diff --git a/backend/onyx/connectors/web/connector.py b/backend/onyx/connectors/web/connector.py index 9bd6e2073..922563374 100644 --- a/backend/onyx/connectors/web/connector.py +++ b/backend/onyx/connectors/web/connector.py @@ -302,29 +302,29 @@ class WebConnector(LoadConnector): playwright, context = start_playwright() restart_playwright = False while to_visit: - current_url = to_visit.pop() - if current_url in visited_links: + initial_url = to_visit.pop() + if initial_url in visited_links: continue - visited_links.add(current_url) + visited_links.add(initial_url) try: - protected_url_check(current_url) + protected_url_check(initial_url) except Exception as e: - last_error = f"Invalid URL {current_url} due to {e}" + last_error = f"Invalid URL {initial_url} due to {e}" logger.warning(last_error) continue - logger.info(f"Visiting {current_url}") + logger.info(f"{len(visited_links)}: Visiting {initial_url}") try: - check_internet_connection(current_url) + check_internet_connection(initial_url) if restart_playwright: playwright, context = start_playwright() restart_playwright = False - if current_url.split(".")[-1] == "pdf": + if initial_url.split(".")[-1] == "pdf": # PDF files are not checked for links - response = requests.get(current_url) + response = requests.get(initial_url) page_text, metadata = read_pdf_file( file=io.BytesIO(response.content) ) @@ -332,10 +332,10 @@ class WebConnector(LoadConnector): doc_batch.append( Document( - id=current_url, - sections=[Section(link=current_url, text=page_text)], + id=initial_url, + sections=[Section(link=initial_url, text=page_text)], source=DocumentSource.WEB, - semantic_identifier=current_url.split("/")[-1], 
+ semantic_identifier=initial_url.split("/")[-1], metadata=metadata, doc_updated_at=_get_datetime_from_last_modified_header( last_modified @@ -347,21 +347,25 @@ class WebConnector(LoadConnector): continue page = context.new_page() - page_response = page.goto(current_url) + page_response = page.goto(initial_url) last_modified = ( page_response.header_value("Last-Modified") if page_response else None ) - final_page = page.url - if final_page != current_url: - logger.info(f"Redirected to {final_page}") - protected_url_check(final_page) - current_url = final_page - if current_url in visited_links: - logger.info("Redirected page already indexed") + final_url = page.url + if final_url != initial_url: + protected_url_check(final_url) + initial_url = final_url + if initial_url in visited_links: + logger.info( + f"{len(visited_links)}: {initial_url} redirected to {final_url} - already indexed" + ) continue - visited_links.add(current_url) + logger.info( + f"{len(visited_links)}: {initial_url} redirected to {final_url}" + ) + visited_links.add(initial_url) if self.scroll_before_scraping: scroll_attempts = 0 @@ -379,13 +383,13 @@ class WebConnector(LoadConnector): soup = BeautifulSoup(content, "html.parser") if self.recursive: - internal_links = get_internal_links(base_url, current_url, soup) + internal_links = get_internal_links(base_url, initial_url, soup) for link in internal_links: if link not in visited_links: to_visit.append(link) if page_response and str(page_response.status)[0] in ("4", "5"): - last_error = f"Skipped indexing {current_url} due to HTTP {page_response.status} response" + last_error = f"Skipped indexing {initial_url} due to HTTP {page_response.status} response" logger.info(last_error) continue @@ -393,12 +397,12 @@ class WebConnector(LoadConnector): doc_batch.append( Document( - id=current_url, + id=initial_url, sections=[ - Section(link=current_url, text=parsed_html.cleaned_text) + Section(link=initial_url, text=parsed_html.cleaned_text) ], 
source=DocumentSource.WEB, - semantic_identifier=parsed_html.title or current_url, + semantic_identifier=parsed_html.title or initial_url, metadata={}, doc_updated_at=_get_datetime_from_last_modified_header( last_modified @@ -410,7 +414,7 @@ class WebConnector(LoadConnector): page.close() except Exception as e: - last_error = f"Failed to fetch '{current_url}': {e}" + last_error = f"Failed to fetch '{initial_url}': {e}" logger.exception(last_error) playwright.stop() restart_playwright = True diff --git a/backend/onyx/main.py b/backend/onyx/main.py index 2444e6f19..003e26fb2 100644 --- a/backend/onyx/main.py +++ b/backend/onyx/main.py @@ -51,7 +51,6 @@ from onyx.server.documents.cc_pair import router as cc_pair_router from onyx.server.documents.connector import router as connector_router from onyx.server.documents.credential import router as credential_router from onyx.server.documents.document import router as document_router -from onyx.server.documents.standard_oauth import router as oauth_router from onyx.server.features.document_set.api import router as document_set_router from onyx.server.features.folder.api import router as folder_router from onyx.server.features.input_prompt.api import ( @@ -323,7 +322,6 @@ def get_application() -> FastAPI: ) include_router_with_global_prefix_prepended(application, long_term_logs_router) include_router_with_global_prefix_prepended(application, api_key_router) - include_router_with_global_prefix_prepended(application, oauth_router) if AUTH_TYPE == AuthType.DISABLED: # Server logs this during auth setup verification step diff --git a/backend/onyx/server/utils.py b/backend/onyx/server/utils.py index 8dc7a429b..8d8643a51 100644 --- a/backend/onyx/server/utils.py +++ b/backend/onyx/server/utils.py @@ -46,13 +46,21 @@ def mask_string(sensitive_str: str) -> str: return "****...**" + sensitive_str[-4:] +MASK_CREDENTIALS_WHITELIST = { + DB_CREDENTIALS_AUTHENTICATION_METHOD, + "wiki_base", + "cloud_name", + "cloud_id", +} + + def 
mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]: masked_creds = {} for key, val in credential_dict.items(): if isinstance(val, str): # we want to pass the authentication_method field through so the frontend # can disambiguate credentials created by different methods - if key == DB_CREDENTIALS_AUTHENTICATION_METHOD: + if key in MASK_CREDENTIALS_WHITELIST: masked_creds[key] = val else: masked_creds[key] = mask_string(val) @@ -63,8 +71,8 @@ def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]: continue raise ValueError( - f"Unable to mask credentials of type other than string, cannot process request." - f"Recieved type: {type(val)}" + f"Unable to mask credentials of type other than string or int, cannot process request." + f"Received type: {type(val)}" ) return masked_creds diff --git a/backend/tests/daily/connectors/confluence/test_confluence_basic.py b/backend/tests/daily/connectors/confluence/test_confluence_basic.py index 26d86c557..7cc80fb2f 100644 --- a/backend/tests/daily/connectors/confluence/test_confluence_basic.py +++ b/backend/tests/daily/connectors/confluence/test_confluence_basic.py @@ -5,7 +5,9 @@ from unittest.mock import patch import pytest +from onyx.configs.constants import DocumentSource from onyx.connectors.confluence.connector import ConfluenceConnector +from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider from onyx.connectors.models import Document @@ -18,12 +20,15 @@ def confluence_connector() -> ConfluenceConnector: page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""), ) - connector.load_credentials( + credentials_provider = OnyxStaticCredentialsProvider( + None, + DocumentSource.CONFLUENCE, { "confluence_username": os.environ["CONFLUENCE_USER_NAME"], "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"], - } + }, ) + connector.set_credentials_provider(credentials_provider) return connector diff --git 
a/backend/tests/daily/connectors/confluence/test_confluence_permissions_basic.py b/backend/tests/daily/connectors/confluence/test_confluence_permissions_basic.py index 0f66a993d..6bb43437e 100644 --- a/backend/tests/daily/connectors/confluence/test_confluence_permissions_basic.py +++ b/backend/tests/daily/connectors/confluence/test_confluence_permissions_basic.py @@ -2,7 +2,9 @@ import os import pytest +from onyx.configs.constants import DocumentSource from onyx.connectors.confluence.connector import ConfluenceConnector +from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider @pytest.fixture @@ -11,12 +13,16 @@ def confluence_connector() -> ConfluenceConnector: wiki_base="https://danswerai.atlassian.net", is_cloud=True, ) - connector.load_credentials( + + credentials_provider = OnyxStaticCredentialsProvider( + None, + DocumentSource.CONFLUENCE, { - "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"], "confluence_username": os.environ["CONFLUENCE_USER_NAME"], - } + "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"], + }, ) + connector.set_credentials_provider(credentials_provider) return connector diff --git a/backend/tests/unit/onyx/connectors/confluence/test_rate_limit_handler.py b/backend/tests/unit/onyx/connectors/confluence/test_rate_limit_handler.py index ed77d7764..c7e88e5dd 100644 --- a/backend/tests/unit/onyx/connectors/confluence/test_rate_limit_handler.py +++ b/backend/tests/unit/onyx/connectors/confluence/test_rate_limit_handler.py @@ -3,9 +3,7 @@ from unittest.mock import Mock import pytest from requests import HTTPError -from onyx.connectors.confluence.onyx_confluence import ( - handle_confluence_rate_limit, -) +from onyx.connectors.confluence.utils import handle_confluence_rate_limit @pytest.fixture @@ -50,6 +48,8 @@ def mock_confluence_call() -> Mock: # mock_sleep.assert_called_with(int(retry_after)) +# NOTE(rkuo): This tests an older version of rate limiting that is being deprecated +# and probably 
should go away soon. def test_non_rate_limit_error(mock_confluence_call: Mock) -> None: mock_confluence_call.side_effect = HTTPError( response=Mock(status_code=500, text="Internal Server Error") diff --git a/web/src/app/admin/assistants/LabelManagement.tsx b/web/src/app/admin/assistants/LabelManagement.tsx index bdb3690cd..7dc9eaff5 100644 --- a/web/src/app/admin/assistants/LabelManagement.tsx +++ b/web/src/app/admin/assistants/LabelManagement.tsx @@ -100,11 +100,6 @@ export default function LabelManagement() { width="w-full max-w-xs" name={`editLabelName_${label.id}`} label="Label Name" - value={ - values.editLabelId === label.id - ? values.editLabelName - : label.name - } onChange={(e) => { setFieldValue("editLabelId", label.id); setFieldValue("editLabelName", e.target.value); diff --git a/web/src/app/admin/assistants/StarterMessageList.tsx b/web/src/app/admin/assistants/StarterMessageList.tsx index b7471269c..f44a59eb1 100644 --- a/web/src/app/admin/assistants/StarterMessageList.tsx +++ b/web/src/app/admin/assistants/StarterMessageList.tsx @@ -52,7 +52,6 @@ export default function StarterMessagesList({ handleInputChange(index, e.target.value)} className="flex-grow" removeLabel diff --git a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx index 44eec87df..4524a32df 100644 --- a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx +++ b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx @@ -193,13 +193,15 @@ export default function AddConnector({ // Check if there are no credentials const noCredentials = credentialTemplate == null; - if (noCredentials && 1 != formStep) { - setFormStep(Math.max(1, formStep)); - } + useEffect(() => { + if (noCredentials && 1 != formStep) { + setFormStep(Math.max(1, formStep)); + } - if (!noCredentials && !credentialActivated && formStep != 0) { - setFormStep(Math.min(formStep, 0)); - } + if (!noCredentials && !credentialActivated && 
formStep != 0) { + setFormStep(Math.min(formStep, 0)); + } + }, [noCredentials, formStep, setFormStep]); const convertStringToDateTime = (indexingStart: string | null) => { return indexingStart ? new Date(indexingStart) : null; diff --git a/web/src/app/admin/connectors/[connector]/oauth/callback/page.tsx b/web/src/app/admin/connectors/[connector]/oauth/callback/page.tsx index 8032c0d7b..eb32c7a87 100644 --- a/web/src/app/admin/connectors/[connector]/oauth/callback/page.tsx +++ b/web/src/app/admin/connectors/[connector]/oauth/callback/page.tsx @@ -33,7 +33,7 @@ export default function OAuthCallbackPage() { const connector = pathname?.split("/")[3]; useEffect(() => { - const handleOAuthCallback = async () => { + const onFirstLoad = async () => { // Examples // connector (url segment)= "google-drive" // sourceType (for looking up metadata) = "google_drive" @@ -85,10 +85,19 @@ export default function OAuthCallbackPage() { } setStatusMessage("Success!"); - setStatusDetails( - `Your authorization with ${sourceMetadata.displayName} completed successfully.` - ); - setRedirectUrl(response.redirect_on_success); // Extract the redirect URL + + // set the continuation link + if (response.finalize_url) { + setRedirectUrl(response.finalize_url); + setStatusDetails( + `Your authorization with ${sourceMetadata.displayName} completed successfully. Additional steps are required to complete credential setup.` + ); + } else { + setRedirectUrl(response.redirect_on_success); + setStatusDetails( + `Your authorization with ${sourceMetadata.displayName} completed successfully.` + ); + } setIsError(false); } catch (error) { console.error("OAuth error:", error); @@ -100,15 +109,15 @@ export default function OAuthCallbackPage() { } }; - handleOAuthCallback(); + onFirstLoad(); }, [code, state, connector]); return ( -
+
} /> -
- +
+

{statusMessage}

{statusDetails}

{redirectUrl && !isError && ( diff --git a/web/src/app/admin/connectors/[connector]/oauth/finalize/page.tsx b/web/src/app/admin/connectors/[connector]/oauth/finalize/page.tsx new file mode 100644 index 000000000..f4e24473c --- /dev/null +++ b/web/src/app/admin/connectors/[connector]/oauth/finalize/page.tsx @@ -0,0 +1,293 @@ +"use client"; + +import { useEffect, useState } from "react"; +import { usePathname, useRouter, useSearchParams } from "next/navigation"; +import { AdminPageTitle } from "@/components/admin/Title"; +import { Button } from "@/components/ui/button"; +import Title from "@/components/ui/title"; +import { KeyIcon } from "@/components/icons/icons"; +import { getSourceMetadata, isValidSource } from "@/lib/sources"; +import { ConfluenceAccessibleResource, ValidSources } from "@/lib/types"; +import CardSection from "@/components/admin/CardSection"; +import { + handleOAuthAuthorizationResponse, + handleOAuthConfluenceFinalize, + handleOAuthPrepareFinalization, +} from "@/lib/oauth_utils"; +import { SelectorFormField } from "@/components/admin/connectors/Field"; +import { ErrorMessage, Field, Form, Formik, useFormikContext } from "formik"; +import * as Yup from "yup"; + +// Helper component to keep the effect logic clean: +function UpdateCloudURLOnCloudIdChange({ + accessibleResources, +}: { + accessibleResources: ConfluenceAccessibleResource[]; +}) { + const { values, setValues, setFieldValue } = useFormikContext<{ + cloud_id: string; + cloud_name: string; + cloud_url: string; + }>(); + + useEffect(() => { + // Whenever cloud_id changes, find the matching resource and update cloud_url + if (values.cloud_id) { + const selectedResource = accessibleResources.find( + (resource) => resource.id === values.cloud_id + ); + if (selectedResource) { + // Update multiple fields together ... somehow setting them in sequence + // doesn't work with the validator + // it may also be possible to await each setFieldValue call. 
+ // https://github.com/jaredpalmer/formik/issues/2266 + setValues((prevValues) => ({ + ...prevValues, + cloud_name: selectedResource.name, + cloud_url: selectedResource.url, + })); + } + } + }, [values.cloud_id, accessibleResources, setFieldValue]); + + // This component doesn't render anything visible: + return null; +} + +export default function OAuthFinalizePage() { + const router = useRouter(); + const searchParams = useSearchParams(); + + const [statusMessage, setStatusMessage] = useState("Processing..."); + const [statusDetails, setStatusDetails] = useState( + "Please wait while we complete the setup." + ); + const [redirectUrl, setRedirectUrl] = useState(null); + const [isError, setIsError] = useState(false); + const [isSubmitted, setIsSubmitted] = useState(false); // New state + const [pageTitle, setPageTitle] = useState( + "Finalize Authorization with Third-Party service" + ); + + const [accessibleResources, setAccessibleResources] = useState< + ConfluenceAccessibleResource[] + >([]); + + // Extract query parameters + const credentialParam = searchParams.get("credential"); + const credential = credentialParam ? 
parseInt(credentialParam, 10) : NaN; + const pathname = usePathname(); + const connector = pathname?.split("/")[3]; + + useEffect(() => { + const onFirstLoad = async () => { + // Examples + // connector (url segment)= "google-drive" + // sourceType (for looking up metadata) = "google_drive" + + if (isNaN(credential)) { + setStatusMessage("Improperly formed OAuth finalization request."); + setStatusDetails("Invalid or missing credential id."); + setIsError(true); + return; + } + + const sourceType = connector.replaceAll("-", "_"); + if (!isValidSource(sourceType)) { + setStatusMessage( + `The specified connector source type ${sourceType} does not exist.` + ); + setStatusDetails(`${sourceType} is not a valid source type.`); + setIsError(true); + return; + } + + const sourceMetadata = getSourceMetadata(sourceType as ValidSources); + setPageTitle(`Finalize Authorization with ${sourceMetadata.displayName}`); + + setStatusMessage("Processing..."); + setStatusDetails( + "Please wait while we retrieve a list of your accessible sites." + ); + setIsError(false); // Ensure no error state during loading + + try { + const response = await handleOAuthPrepareFinalization( + connector, + credential + ); + + if (!response) { + throw new Error("Empty response from OAuth server."); + } + + setAccessibleResources(response.accessible_resources); + + setStatusMessage("Select a Confluence site"); + setStatusDetails(""); + + setIsError(false); + } catch (error) { + console.error("OAuth finalization error:", error); + setStatusMessage("Oops, something went wrong!"); + setStatusDetails( + "An error occurred during the OAuth finalization process. Please try again." + ); + setIsError(true); + } + }; + + onFirstLoad(); + }, [credential, connector]); + + useEffect(() => {}, [redirectUrl]); + + return ( +
+ } /> + +
+ +

{statusMessage}

+

{statusDetails}

+ + { + formikHelpers.setSubmitting(true); + try { + if (!values.cloud_id) { + throw new Error("Cloud ID is required."); + } + + if (!values.cloud_name) { + throw new Error("Cloud URL is required."); + } + + if (!values.cloud_url) { + throw new Error("Cloud URL is required."); + } + + const response = await handleOAuthConfluenceFinalize( + values.credential_id, + values.cloud_id, + values.cloud_name, + values.cloud_url + ); + formikHelpers.setSubmitting(false); + + if (response) { + setRedirectUrl(response.redirect_url); + setStatusMessage("Confluence authorization finalized."); + } + + setIsSubmitted(true); // Mark as submitted + } catch (error) { + console.error(error); + setStatusMessage("Error during submission."); + setStatusDetails( + "An error occurred during the submission process. Please try again." + ); + setIsError(true); + formikHelpers.setSubmitting(false); + } + }} + > + {({ isSubmitting, isValid, setFieldValue }) => ( +
+ {/* Debug info +
+
+                    isValid: {String(isValid)}
+                    errors: {JSON.stringify(errors, null, 2)}
+                    values: {JSON.stringify(values, null, 2)}
+                  
+
*/} + + {/* Our helper component that reacts to changes in cloud_id */} + + + + + + + + + {!redirectUrl && accessibleResources.length > 0 && ( + ({ + name: `${resource.name} - ${resource.url}`, + value: resource.id, + }))} + onSelect={(selectedValue) => { + const selectedResource = accessibleResources.find( + (resource) => resource.id === selectedValue + ); + if (selectedResource) { + setFieldValue("cloud_id", selectedResource.id); + } + }} + /> + )} +
+ {!redirectUrl && ( + + )} + + )} +
+ + {redirectUrl && !isError && ( +
+

+ Authorization finalized. Click{" "} + + here + {" "} + to continue. +

+
+ )} +
+
+
+ ); +} diff --git a/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx b/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx index 53a91a38f..7268ae0ca 100644 --- a/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx @@ -1,4 +1,10 @@ -import React, { Dispatch, FC, SetStateAction, useState } from "react"; +import React, { + Dispatch, + FC, + SetStateAction, + useEffect, + useState, +} from "react"; import CredentialSubText from "@/components/credentials/CredentialFields"; import { ConnectionConfiguration } from "@/lib/connectors/connectors"; import { TextFormField } from "@/components/admin/connectors/Field"; @@ -8,6 +14,7 @@ import { AccessTypeGroupSelector } from "@/components/admin/connectors/AccessTyp import { ConfigurableSources } from "@/lib/types"; import { Credential } from "@/lib/connectors/credentials"; import { RenderField } from "./FieldRendering"; +import { useFormikContext } from "formik"; export interface DynamicConnectionFormProps { config: ConnectionConfiguration; @@ -22,7 +29,25 @@ const DynamicConnectionForm: FC = ({ connector, currentCredential, }) => { + const { setFieldValue } = useFormikContext(); // Get Formik's context functions + const [showAdvancedOptions, setShowAdvancedOptions] = useState(false); + const [connectorNameInitialized, setConnectorNameInitialized] = + useState(false); + + let initialConnectorName = ""; + if (config.initialConnectorName) { + initialConnectorName = + currentCredential?.credential_json?.[config.initialConnectorName] ?? 
""; + } + + useEffect(() => { + const field_value = values["name"]; + if (initialConnectorName && !connectorNameInitialized && !field_value) { + setFieldValue("name", initialConnectorName); + setConnectorNameInitialized(true); + } + }, [initialConnectorName, setFieldValue, values]); return ( <> diff --git a/web/src/app/admin/connectors/[connector]/pages/FieldRendering.tsx b/web/src/app/admin/connectors/[connector]/pages/FieldRendering.tsx index 41cb446aa..a7577f662 100644 --- a/web/src/app/admin/connectors/[connector]/pages/FieldRendering.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/FieldRendering.tsx @@ -1,4 +1,4 @@ -import React, { Dispatch, FC, SetStateAction } from "react"; +import React, { Dispatch, FC, SetStateAction, useEffect } from "react"; import { AdminBooleanFormField } from "@/components/credentials/CredentialFields"; import { FileUpload } from "@/components/admin/connectors/FileUpload"; import { TabOption } from "@/lib/connectors/connectors"; @@ -16,6 +16,7 @@ import { TabsList, TabsTrigger, } from "@/components/ui/fully_wrapped_tabs"; +import { useFormikContext } from "formik"; interface TabsFieldProps { tabField: TabOption; @@ -123,6 +124,8 @@ export const RenderField: FC = ({ connector, currentCredential, }) => { + const { setFieldValue } = useFormikContext(); // Get Formik's context functions + const label = typeof field.label === "function" ? field.label(currentCredential) @@ -131,6 +134,22 @@ export const RenderField: FC = ({ typeof field.description === "function" ? field.description(currentCredential) : field.description; + const disabled = + typeof field.disabled === "function" + ? field.disabled(currentCredential) + : (field.disabled ?? false); + const initialValue = + typeof field.initial === "function" + ? field.initial(currentCredential) + : (field.initial ?? 
""); + + // if initialValue exists, prepopulate the field with it + useEffect(() => { + const field_value = values[field.name]; + if (initialValue && field_value === undefined) { + setFieldValue(field.name, initialValue); + } + }, [field.name, initialValue, setFieldValue, values]); if (field.type === "tab") { return ( @@ -176,6 +195,8 @@ export const RenderField: FC = ({ subtext={description} name={field.name} label={label} + disabled={disabled} + onChange={(e) => setFieldValue(field.name, e.target.value)} /> ) : field.type === "text" ? ( = ({ name={field.name} isTextArea={field.isTextArea || false} defaultHeight={"h-15"} + disabled={disabled} + onChange={(e) => setFieldValue(field.name, e.target.value)} /> ) : field.type === "string_tab" ? (
{description}
diff --git a/web/src/components/admin/connectors/Field.tsx b/web/src/components/admin/connectors/Field.tsx index 851e3b380..15480ab61 100644 --- a/web/src/components/admin/connectors/Field.tsx +++ b/web/src/components/admin/connectors/Field.tsx @@ -132,7 +132,6 @@ export function TextFormField({ label, subtext, placeholder, - value, type = "text", optional, includeRevert, @@ -157,7 +156,6 @@ export function TextFormField({ vertical, className, }: { - value?: string; // Escape hatch for setting the value of the field - conflicts with Formik name: string; removeLabel?: boolean; label: string; @@ -253,7 +251,6 @@ export function TextFormField({ min={min} as={isTextArea ? "textarea" : "input"} type={type} - defaultValue={value} data-testid={name} name={name} id={name} diff --git a/web/src/components/credentials/CredentialFields.tsx b/web/src/components/credentials/CredentialFields.tsx index d23b30834..70e6f8809 100644 --- a/web/src/components/credentials/CredentialFields.tsx +++ b/web/src/components/credentials/CredentialFields.tsx @@ -130,6 +130,7 @@ interface BooleanFormFieldProps { small?: boolean; alignTop?: boolean; noLabel?: boolean; + disabled?: boolean; onChange?: (e: React.ChangeEvent) => void; } @@ -141,6 +142,7 @@ export const AdminBooleanFormField = ({ small, checked, alignTop, + disabled = false, onChange, }: BooleanFormFieldProps) => { const [field, meta, helpers] = useField(name); @@ -152,6 +154,7 @@ export const AdminBooleanFormField = ({ type="checkbox" {...field} checked={Boolean(field.value)} + disabled={disabled} onChange={(e) => { helpers.setValue(e.target.checked); }} diff --git a/web/src/lib/connectors/connectors.tsx b/web/src/lib/connectors/connectors.tsx index 37cb373b9..99ccde947 100644 --- a/web/src/lib/connectors/connectors.tsx +++ b/web/src/lib/connectors/connectors.tsx @@ -43,6 +43,7 @@ export interface Option { currentCredential: Credential | null ) => boolean; wrapInCollapsible?: boolean; + disabled?: boolean | ((currentCredential: 
Credential | null) => boolean); } export interface SelectOption extends Option { @@ -60,6 +61,7 @@ export interface ListOption extends Option { export interface TextOption extends Option { type: "text"; default?: string; + initial?: string | ((currentCredential: Credential | null) => string); isTextArea?: boolean; } @@ -105,6 +107,7 @@ export interface TabOption extends Option { export interface ConnectionConfiguration { description: string; subtext?: string; + initialConnectorName?: string; // a key in the credential to prepopulate the connector name field values: ( | BooleanOption | ListOption @@ -389,6 +392,7 @@ export const connectorConfigs: Record< }, confluence: { description: "Configure Confluence connector", + initialConnectorName: "cloud_name", values: [ { type: "checkbox", @@ -399,6 +403,12 @@ export const connectorConfigs: Record< default: true, description: "Check if this is a Confluence Cloud instance, uncheck for Confluence Server/Data Center", + disabled: (currentCredential) => { + if (currentCredential?.credential_json?.confluence_refresh_token) { + return true; + } + return false; + }, }, { type: "text", @@ -406,6 +416,15 @@ export const connectorConfigs: Record< label: "Wiki Base URL", name: "wiki_base", optional: false, + initial: (currentCredential) => { + return currentCredential?.credential_json?.wiki_base ?? 
""; + }, + disabled: (currentCredential) => { + if (currentCredential?.credential_json?.confluence_refresh_token) { + return true; + } + return false; + }, description: "The base URL of your Confluence instance (e.g., https://your-domain.atlassian.net/wiki)", }, diff --git a/web/src/lib/connectors/oauth.ts b/web/src/lib/connectors/oauth.ts index d3472ccba..7100f2609 100644 --- a/web/src/lib/connectors/oauth.ts +++ b/web/src/lib/connectors/oauth.ts @@ -27,6 +27,9 @@ export async function getConnectorOauthRedirectUrl( export function useOAuthDetails(sourceType: ValidSources) { return useSWR( `/api/connector/oauth/details/${sourceType}`, - errorHandlingFetcher + errorHandlingFetcher, + { + shouldRetryOnError: false, + } ); } diff --git a/web/src/lib/oauth_utils.ts b/web/src/lib/oauth_utils.ts index 5f4329a15..db3342efe 100644 --- a/web/src/lib/oauth_utils.ts +++ b/web/src/lib/oauth_utils.ts @@ -1,5 +1,7 @@ import { - OAuthGoogleDriveCallbackResponse, + OAuthBaseCallbackResponse, + OAuthConfluenceFinalizeResponse, + OAuthConfluencePrepareFinalizationResponse, OAuthPrepareAuthorizationResponse, OAuthSlackCallbackResponse, } from "./types"; @@ -53,6 +55,10 @@ export async function handleOAuthAuthorizationResponse( return handleOAuthGoogleDriveAuthorizationResponse(code, state); } + if (connector === "confluence") { + return handleOAuthConfluenceAuthorizationResponse(code, state); + } + return; } @@ -75,7 +81,7 @@ export async function handleOAuthSlackAuthorizationResponse( }); if (!response.ok) { - let errorDetails = `Failed to handle OAuth authorization response: ${response.status}`; + let errorDetails = `Failed to handle OAuth Slack authorization response: ${response.status}`; try { const responseBody = await response.text(); // Read the body as text @@ -96,12 +102,10 @@ export async function handleOAuthSlackAuthorizationResponse( return data; } -// server side handler to process the oauth redirect callback -// https://api.slack.com/authentication/oauth-v2#exchanging 
export async function handleOAuthGoogleDriveAuthorizationResponse( code: string, state: string -): Promise { +): Promise { const url = `/api/oauth/connector/google-drive/callback?code=${encodeURIComponent( code )}&state=${encodeURIComponent(state)}`; @@ -115,7 +119,7 @@ export async function handleOAuthGoogleDriveAuthorizationResponse( }); if (!response.ok) { - let errorDetails = `Failed to handle OAuth authorization response: ${response.status}`; + let errorDetails = `Failed to handle OAuth Google Drive authorization response: ${response.status}`; try { const responseBody = await response.text(); // Read the body as text @@ -132,6 +136,137 @@ export async function handleOAuthGoogleDriveAuthorizationResponse( } // Parse the JSON response - const data = (await response.json()) as OAuthGoogleDriveCallbackResponse; + const data = (await response.json()) as OAuthBaseCallbackResponse; + return data; +} + +// call server side helper +// https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps +export async function handleOAuthConfluenceAuthorizationResponse( + code: string, + state: string +): Promise { + const url = `/api/oauth/connector/confluence/callback?code=${encodeURIComponent( + code + )}&state=${encodeURIComponent(state)}`; + + const response = await fetch(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ code, state }), + }); + + if (!response.ok) { + let errorDetails = `Failed to handle OAuth Confluence authorization response: ${response.status}`; + + try { + const responseBody = await response.text(); // Read the body as text + errorDetails += `\nResponse Body: ${responseBody}`; + } catch (err) { + if (err instanceof Error) { + errorDetails += `\nUnable to read response body: ${err.message}`; + } else { + errorDetails += `\nUnable to read response body: Unknown error type`; + } + } + + throw new Error(errorDetails); + } + + // Parse the JSON response + const data = (await response.json()) as 
OAuthBaseCallbackResponse; + return data; +} + +export async function handleOAuthPrepareFinalization( + connector: string, + credential: number +) { + if (connector === "confluence") { + return handleOAuthConfluencePrepareFinalization(credential); + } + + return; +} + +// call server side helper +// https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps +export async function handleOAuthConfluencePrepareFinalization( + credential: number +): Promise { + const url = `/api/oauth/connector/confluence/accessible-resources?credential_id=${encodeURIComponent( + credential + )}`; + + const response = await fetch(url, { + method: "GET", + headers: { + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + let errorDetails = `Failed to handle OAuth Confluence prepare finalization response: ${response.status}`; + + try { + const responseBody = await response.text(); // Read the body as text + errorDetails += `\nResponse Body: ${responseBody}`; + } catch (err) { + if (err instanceof Error) { + errorDetails += `\nUnable to read response body: ${err.message}`; + } else { + errorDetails += `\nUnable to read response body: Unknown error type`; + } + } + + throw new Error(errorDetails); + } + + // Parse the JSON response + const data = + (await response.json()) as OAuthConfluencePrepareFinalizationResponse; + return data; +} + +export async function handleOAuthConfluenceFinalize( + credential_id: number, + cloud_id: string, + cloud_name: string, + cloud_url: string +): Promise { + const url = `/api/oauth/connector/confluence/finalize?credential_id=${encodeURIComponent( + credential_id + )}&cloud_id=${encodeURIComponent(cloud_id)}&cloud_name=${encodeURIComponent( + cloud_name + )}&cloud_url=${encodeURIComponent(cloud_url)}`; + + const response = await fetch(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + let errorDetails = `Failed to handle OAuth Confluence finalization response: 
${response.status}`; + + try { + const responseBody = await response.text(); // Read the body as text + errorDetails += `\nResponse Body: ${responseBody}`; + } catch (err) { + if (err instanceof Error) { + errorDetails += `\nUnable to read response body: ${err.message}`; + } else { + errorDetails += `\nUnable to read response body: Unknown error type`; + } + } + + throw new Error(errorDetails); + } + + // Parse the JSON response + const data = (await response.json()) as OAuthConfluenceFinalizeResponse; return data; } diff --git a/web/src/lib/sources.ts b/web/src/lib/sources.ts index 846918986..7a3341256 100644 --- a/web/src/lib/sources.ts +++ b/web/src/lib/sources.ts @@ -120,6 +120,7 @@ export const SOURCE_METADATA_MAP: SourceMap = { displayName: "Confluence", category: SourceCategory.Wiki, docs: "https://docs.onyx.app/connectors/confluence", + oauthSupported: true, }, jira: { icon: JiraIcon, diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index 7288fe905..70476663c 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -167,18 +167,36 @@ export interface OAuthPrepareAuthorizationResponse { url: string; } -export interface OAuthSlackCallbackResponse { +export interface OAuthBaseCallbackResponse { success: boolean; message: string; - team_id: string; - authed_user_id: string; + finalize_url: string | null; redirect_on_success: string; } -export interface OAuthGoogleDriveCallbackResponse { +export interface OAuthSlackCallbackResponse extends OAuthBaseCallbackResponse { + team_id: string; + authed_user_id: string; +} + +export interface ConfluenceAccessibleResource { + id: string; + name: string; + url: string; + scopes: string[]; + avatarUrl: string; +} + +export interface OAuthConfluencePrepareFinalizationResponse { success: boolean; message: string; - redirect_on_success: string; + accessible_resources: ConfluenceAccessibleResource[]; +} + +export interface OAuthConfluenceFinalizeResponse { + success: boolean; + message: string; + redirect_url: 
string; } export interface CCPairBasicInfo { @@ -382,6 +400,7 @@ export const oauthSupportedSources: ConfigurableSources[] = [ ValidSources.Slack, // NOTE: temporarily disabled until our GDrive App is approved // ValidSources.GoogleDrive, + ValidSources.Confluence, ]; export type OAuthSupportedSource = (typeof oauthSupportedSources)[number];