Feature/confluence oauth (#3477)

* first cut at slack oauth flow

* fix usage of hooks

* fix button spacing

* add additional error logging

* no dev redirect

* early cut at google drive oauth

* second pass

* switch to production uri's

* try handling oauth_interactive differently

* pass through client id and secret if uploaded

* fix call

* fix test

* temporarily disable check for testing

* Revert "temporarily disable check for testing"

This reverts commit 4b5a022a5fe38b05355a561616068af8e969def2.

* support visibility in test

* missed file

* first cut at confluence oauth

* work in progress

* work in progress

* work in progress

* work in progress

* work in progress

* first cut at distributed locking

* WIP to make test work

* add some dev mode affordances and gate usage of redis behind dynamic credentials

* mypy and credentials provider fixes

* WIP

* fix created at

* fix setting initialValue on everything

* remove debugging, fix ??? some TextFormField issues

* npm fixes

* comment cleanup

* fix comments

* pin the size of the card section

* more review fixes

* more fixes

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
This commit is contained in:
rkuo-danswer 2025-02-27 19:48:51 -08:00 committed by GitHub
parent cd84b65011
commit 909403a648
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
43 changed files with 2531 additions and 1119 deletions

View File

@ -59,10 +59,14 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "")
OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "")
OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "")
OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
"OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", ""
)
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get(
"OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", ""
)
OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "")
OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "")
OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "")
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""

View File

@ -9,12 +9,16 @@ from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GR
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.confluence.onyx_confluence import (
get_user_email_from_username__server,
)
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@ -342,7 +346,8 @@ def _fetch_all_page_restrictions(
def confluence_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@ -354,7 +359,11 @@ def confluence_doc_sync(
confluence_connector = ConfluenceConnector(
**cc_pair.connector.connector_specific_config
)
confluence_connector.load_credentials(cc_pair.credential.credential_json)
provider = OnyxDBCredentialsProvider(
get_current_tenant_id(), "confluence", cc_pair.credential_id
)
confluence_connector.set_credentials_provider(provider)
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)

View File

@ -1,9 +1,11 @@
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from onyx.background.error_logging import emit_background_error
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
from onyx.connectors.confluence.onyx_confluence import (
get_user_email_from_username__server,
)
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
@ -61,13 +63,27 @@ def _build_group_member_email_map(
def confluence_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
confluence_client = build_confluence_client(
credentials=cc_pair.credential.credential_json,
is_cloud=cc_pair.connector.connector_specific_config.get("is_cloud", False),
wiki_base=cc_pair.connector.connector_specific_config["wiki_base"],
)
provider = OnyxDBCredentialsProvider(tenant_id, "confluence", cc_pair.credential_id)
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
wiki_base: str = cc_pair.connector.connector_specific_config["wiki_base"]
url = wiki_base.rstrip("/")
probe_kwargs = {
"max_backoff_retries": 6,
"max_backoff_seconds": 10,
}
final_kwargs = {
"max_backoff_retries": 10,
"max_backoff_seconds": 60,
}
confluence_client = OnyxConfluence(is_cloud, url, provider)
confluence_client._probe_connection(**probe_kwargs)
confluence_client._initialize_connection(**final_kwargs)
group_member_email_map = _build_group_member_email_map(
confluence_client=confluence_client,

View File

@ -32,7 +32,8 @@ def _get_slim_doc_generator(
def gmail_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres

View File

@ -145,7 +145,8 @@ def _get_permissions_from_slim_doc(
def gdrive_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres

View File

@ -119,6 +119,7 @@ def _build_onyx_groups(
def gdrive_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
# Initialize connector and build credential/service objects

View File

@ -123,7 +123,8 @@ def _fetch_channel_permissions(
def slack_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres

View File

@ -28,6 +28,7 @@ DocSyncFuncType = Callable[
GroupSyncFuncType = Callable[
[
str,
ConnectorCredentialPair,
],
list[ExternalUserGroup],

View File

@ -15,7 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
)
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
from ee.onyx.server.oauth import router as oauth_router
from ee.onyx.server.oauth.api import router as oauth_router
from ee.onyx.server.query_and_chat.chat_backend import (
router as chat_router,
)
@ -152,4 +152,8 @@ def get_application() -> FastAPI:
# environment variable. Used to automate deployment for multiple environments.
seed_db()
# for debugging discovered routes
# for route in application.router.routes:
# print(f"Path: {route.path}, Methods: {route.methods}")
return application

View File

@ -1,629 +0,0 @@
import base64
import json
import uuid
from typing import Any
from typing import cast
import requests
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from onyx.auth.users import current_user
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/oauth")
class SlackOAuth:
# https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
# Example: https://api.slack.com/authentication/oauth-v2#exchanging
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_SLACK_CLIENT_ID
CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET
TOKEN_URL = "https://slack.com/api/oauth.v2.access"
# SCOPE is per https://docs.onyx.app/connectors/slack
BOT_SCOPE = (
"channels:history,"
"channels:read,"
"groups:history,"
"groups:read,"
"channels:join,"
"im:history,"
"users:read,"
"users:read.email,"
"usergroups:read"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = SlackOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session
class ConfluenceCloudOAuth:
"""work in progress"""
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
CONFLUENCE_OAUTH_SCOPE = (
"read:confluence-props%20"
"read:confluence-content.all%20"
"read:confluence-content.summary%20"
"read:confluence-content.permission%20"
"read:confluence-user%20"
"read:confluence-groups%20"
"readonly:content.attachment:confluence"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
# eventually for Confluence Data Center
# oauth_url = (
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# f"&redirect_uri={redirectme_uri}"
# )
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
"https://auth.atlassian.com/authorize"
f"?audience=api.atlassian.com"
f"&client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
f"&state={state}"
"&response_type=code"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = ConfluenceCloudOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session
class GoogleDriveOAuth:
# https://developers.google.com/identity/protocols/oauth2
# https://developers.google.com/identity/protocols/oauth2/web-server
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
TOKEN_URL = "https://oauth2.googleapis.com/token"
# SCOPE is per https://docs.onyx.app/connectors/google-drive
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
SCOPE = (
"https://www.googleapis.com/auth/drive.readonly%20"
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
"https://www.googleapis.com/auth/admin.directory.group.readonly"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# without prompt=consent, a refresh token is only issued the first time the user approves
url = (
f"https://accounts.google.com/o/oauth2/v2/auth"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
"&response_type=code"
f"&scope={cls.SCOPE}"
"&access_type=offline"
f"&state={state}"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = GoogleDriveOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
return session
@router.post("/prepare-authorization-request")
def prepare_authorization_request(
connector: DocumentSource,
redirect_on_success: str | None,
user: User = Depends(current_user),
) -> JSONResponse:
"""Used by the frontend to generate the url for the user's browser during auth request.
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""
tenant_id = get_current_tenant_id()
# create random oauth state param for security and to retrieve user data later
oauth_uuid = uuid.uuid4()
oauth_uuid_str = str(oauth_uuid)
# urlsafe b64 encode the uuid for the oauth url
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
session: str
if connector == DocumentSource.SLACK:
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
session = SlackOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
elif connector == DocumentSource.GOOGLE_DRIVE:
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
session = GoogleDriveOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
# elif connector == DocumentSource.CONFLUENCE:
# oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
# session = ConfluenceCloudOAuth.session_dump_json(
# email=user.email, redirect_on_success=redirect_on_success
# )
# elif connector == DocumentSource.JIRA:
# oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
else:
oauth_url = None
if not oauth_url:
raise HTTPException(
status_code=404,
detail=f"The document source type {connector} does not have OAuth implemented",
)
r = get_redis_client(tenant_id=tenant_id)
# store important session state to retrieve when the user is redirected back
# 10 min is the max we want an oauth flow to be valid
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
return JSONResponse(content={"url": oauth_url})
@router.post("/connector/slack/callback")
def handle_slack_oauth_callback(
code: str,
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> JSONResponse:
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Slack client ID or client secret is not configured.",
)
r = get_redis_client()
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = SlackOAuth.parse_session(session_json)
# Exchange the authorization code for an access token
response = requests.post(
SlackOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": SlackOAuth.CLIENT_ID,
"client_secret": SlackOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": SlackOAuth.REDIRECT_URI,
},
)
response_data = response.json()
if not response_data.get("ok"):
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed: {response_data.get('error')}",
)
# Extract token and team information
access_token: str = response_data.get("access_token")
team_id: str = response_data.get("team", {}).get("id")
authed_user_id: str = response_data.get("authed_user", {}).get("id")
credential_info = CredentialBase(
credential_json={"slack_bot_token": access_token},
admin_public=True,
source=DocumentSource.SLACK,
name="Slack OAuth",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Slack OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Slack OAuth completed successfully.",
"team_id": team_id,
"authed_user_id": authed_user_id,
"redirect_on_success": session.redirect_on_success,
}
)
# Work in progress
# @router.post("/connector/confluence/callback")
# def handle_confluence_oauth_callback(
# code: str,
# state: str,
# user: User = Depends(current_user),
# db_session: Session = Depends(get_session),
# tenant_id: str | None = Depends(get_current_tenant_id),
# ) -> JSONResponse:
# if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
# raise HTTPException(
# status_code=500,
# detail="Confluence client ID or client secret is not configured."
# )
# r = get_redis_client(tenant_id=tenant_id)
# # recover the state
# padded_state = state + '=' * (-len(state) % 4) # Add padding back (Base64 decoding requires padding)
# uuid_bytes = base64.urlsafe_b64decode(padded_state) # Decode the Base64 string back to bytes
# # Convert bytes back to a UUID
# oauth_uuid = uuid.UUID(bytes=uuid_bytes)
# oauth_uuid_str = str(oauth_uuid)
# r_key = f"da_oauth:{oauth_uuid_str}"
# result = r.get(r_key)
# if not result:
# raise HTTPException(
# status_code=400,
# detail=f"Confluence OAuth failed - OAuth state key not found: key={r_key}"
# )
# try:
# session = ConfluenceCloudOAuth.parse_session(result)
# # Exchange the authorization code for an access token
# response = requests.post(
# ConfluenceCloudOAuth.TOKEN_URL,
# headers={"Content-Type": "application/x-www-form-urlencoded"},
# data={
# "client_id": ConfluenceCloudOAuth.CLIENT_ID,
# "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
# "code": code,
# "redirect_uri": ConfluenceCloudOAuth.DEV_REDIRECT_URI,
# },
# )
# response_data = response.json()
# if not response_data.get("ok"):
# raise HTTPException(
# status_code=400,
# detail=f"ConfluenceCloudOAuth OAuth failed: {response_data.get('error')}"
# )
# # Extract token and team information
# access_token: str = response_data.get("access_token")
# team_id: str = response_data.get("team", {}).get("id")
# authed_user_id: str = response_data.get("authed_user", {}).get("id")
# credential_info = CredentialBase(
# credential_json={"slack_bot_token": access_token},
# admin_public=True,
# source=DocumentSource.CONFLUENCE,
# name="Confluence OAuth",
# )
# logger.info(f"Slack access token: {access_token}")
# credential = create_credential(credential_info, user, db_session)
# logger.info(f"new_credential_id={credential.id}")
# except Exception as e:
# return JSONResponse(
# status_code=500,
# content={
# "success": False,
# "message": f"An error occurred during Slack OAuth: {str(e)}",
# },
# )
# finally:
# r.delete(r_key)
# # return the result
# return JSONResponse(
# content={
# "success": True,
# "message": "Slack OAuth completed successfully.",
# "team_id": team_id,
# "authed_user_id": authed_user_id,
# "redirect_on_success": session.redirect_on_success,
# }
# )
@router.post("/connector/google-drive/callback")
def handle_google_drive_oauth_callback(
code: str,
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> JSONResponse:
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Google Drive client ID or client secret is not configured.",
)
r = get_redis_client()
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
session: GoogleDriveOAuth.OAuthSession
try:
session = GoogleDriveOAuth.parse_session(session_json)
# Exchange the authorization code for an access token
response = requests.post(
GoogleDriveOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": GoogleDriveOAuth.CLIENT_ID,
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
"grant_type": "authorization_code",
},
)
response.raise_for_status()
authorization_response: dict[str, Any] = response.json()
# the connector wants us to store the json in its authorized_user_info format
# returned from OAuthCredentials.get_authorized_user_info().
# So refresh immediately via get_google_oauth_creds with the params filled in
# from fields in authorization_response to get the json we need
authorized_user_info = {}
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
token_json_str = json.dumps(authorized_user_info)
oauth_creds = get_google_oauth_creds(
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
)
if not oauth_creds:
raise RuntimeError("get_google_oauth_creds returned None.")
# save off the credentials
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
credential_dict: dict[str, str] = {}
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
credential_dict[
DB_CREDENTIALS_AUTHENTICATION_METHOD
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
credential_info = CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
name="OAuth (interactive)",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Google Drive OAuth completed successfully.",
"redirect_on_success": session.redirect_on_success,
}
)

View File

@ -0,0 +1,91 @@
import base64
import uuid
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from ee.onyx.server.oauth.api_router import router
from ee.onyx.server.oauth.confluence_cloud import ConfluenceCloudOAuth
from ee.onyx.server.oauth.google_drive import GoogleDriveOAuth
from ee.onyx.server.oauth.slack import SlackOAuth
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import DocumentSource
from onyx.db.engine import get_current_tenant_id
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
@router.post("/prepare-authorization-request")
def prepare_authorization_request(
connector: DocumentSource,
redirect_on_success: str | None,
user: User = Depends(current_admin_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Used by the frontend to generate the url for the user's browser during auth request.
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""
# create random oauth state param for security and to retrieve user data later
oauth_uuid = uuid.uuid4()
oauth_uuid_str = str(oauth_uuid)
# urlsafe b64 encode the uuid for the oauth url
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
session: str | None = None
if connector == DocumentSource.SLACK:
if not DEV_MODE:
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
else:
oauth_url = SlackOAuth.generate_dev_oauth_url(oauth_state)
session = SlackOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
elif connector == DocumentSource.CONFLUENCE:
if not DEV_MODE:
oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
else:
oauth_url = ConfluenceCloudOAuth.generate_dev_oauth_url(oauth_state)
session = ConfluenceCloudOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
elif connector == DocumentSource.GOOGLE_DRIVE:
if not DEV_MODE:
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
else:
oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
session = GoogleDriveOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
else:
oauth_url = None
if not oauth_url:
raise HTTPException(
status_code=404,
detail=f"The document source type {connector} does not have OAuth implemented",
)
if not session:
raise HTTPException(
status_code=500,
detail=f"The document source type {connector} failed to generate an OAuth session.",
)
r = get_redis_client(tenant_id=tenant_id)
# store important session state to retrieve when the user is redirected back
# 10 min is the max we want an oauth flow to be valid
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
return JSONResponse(content={"url": oauth_url})

View File

@ -0,0 +1,3 @@
from fastapi import APIRouter
router: APIRouter = APIRouter(prefix="/oauth")

View File

@ -0,0 +1,361 @@
import base64
import uuid
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
import requests
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
from ee.onyx.server.oauth.api_router import router
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.utils import CONFLUENCE_OAUTH_TOKEN_URL
from onyx.db.credentials import create_credential
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.credentials import update_credential_json
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
logger = setup_logger()
class ConfluenceCloudOAuth:
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
class TokenResponse(BaseModel):
access_token: str
expires_in: int
token_type: str
refresh_token: str
scope: str
class AccessibleResources(BaseModel):
id: str
name: str
url: str
scopes: list[str]
avatarUrl: str
CLIENT_ID = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
CLIENT_SECRET = OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
TOKEN_URL = CONFLUENCE_OAUTH_TOKEN_URL
ACCESSIBLE_RESOURCE_URL = (
"https://api.atlassian.com/oauth/token/accessible-resources"
)
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
CONFLUENCE_OAUTH_SCOPE = (
# classic scope
"read:confluence-space.summary%20"
"read:confluence-props%20"
"read:confluence-content.all%20"
"read:confluence-content.summary%20"
"read:confluence-content.permission%20"
"read:confluence-user%20"
"read:confluence-groups%20"
"readonly:content.attachment:confluence%20"
"search:confluence%20"
# granular scope
"read:attachment:confluence%20" # possibly unneeded unless calling v2 attachments api
"offline_access"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
# eventually for Confluence Data Center
# oauth_url = (
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# f"&redirect_uri={redirectme_uri}"
# )
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# https://developer.atlassian.com/cloud/jira/platform/oauth-2-3lo-apps/#1--direct-the-user-to-the-authorization-url-to-get-an-authorization-code
url = (
"https://auth.atlassian.com/authorize"
f"?audience=api.atlassian.com"
f"&client_id={cls.CLIENT_ID}"
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
f"&redirect_uri={redirect_uri}"
f"&state={state}"
"&response_type=code"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = ConfluenceCloudOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = ConfluenceCloudOAuth.OAuthSession.model_validate_json(session_json)
return session
@classmethod
def generate_finalize_url(cls, credential_id: int) -> str:
return f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/finalize?credential={credential_id}"
@router.post("/connector/confluence/callback")
def confluence_oauth_callback(
code: str,
state: str,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Handles the backend logic for the frontend page that the user is redirected to
after visiting the oauth authorization url."""
if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Confluence Cloud client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Confluence Cloud OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = ConfluenceCloudOAuth.parse_session(session_json)
if not DEV_MODE:
redirect_uri = ConfluenceCloudOAuth.REDIRECT_URI
else:
redirect_uri = ConfluenceCloudOAuth.DEV_REDIRECT_URI
# Exchange the authorization code for an access token
response = requests.post(
ConfluenceCloudOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": ConfluenceCloudOAuth.CLIENT_ID,
"client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": redirect_uri,
"grant_type": "authorization_code",
},
)
token_response: ConfluenceCloudOAuth.TokenResponse | None = None
try:
token_response = ConfluenceCloudOAuth.TokenResponse.model_validate_json(
response.text
)
except Exception:
raise RuntimeError(
"Confluence Cloud OAuth failed during code/token exchange."
)
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=token_response.expires_in)
credential_info = CredentialBase(
credential_json={
"confluence_access_token": token_response.access_token,
"confluence_refresh_token": token_response.refresh_token,
"created_at": now.isoformat(),
"expires_at": expires_at.isoformat(),
"expires_in": token_response.expires_in,
"scope": token_response.scope,
},
admin_public=True,
source=DocumentSource.CONFLUENCE,
name="Confluence Cloud OAuth",
)
credential = create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Confluence Cloud OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Confluence Cloud OAuth completed successfully.",
"finalize_url": ConfluenceCloudOAuth.generate_finalize_url(credential.id),
"redirect_on_success": session.redirect_on_success,
}
)
@router.get("/connector/confluence/accessible-resources")
def confluence_oauth_accessible_resources(
credential_id: int,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Atlassian's API is weird and does not supply us with enough info to be in a
usable state after authorizing. All API's require a cloud id. We have to list
the accessible resources/sites and let the user choose which site to use."""
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if not credential:
raise HTTPException(400, f"Credential {credential_id} not found.")
credential_dict = credential.credential_json
access_token = credential_dict["confluence_access_token"]
try:
# Exchange the authorization code for an access token
response = requests.get(
ConfluenceCloudOAuth.ACCESSIBLE_RESOURCE_URL,
headers={
"Authorization": f"Bearer {access_token}",
"Accept": "application/json",
},
)
response.raise_for_status()
accessible_resources_data = response.json()
# Validate the list of AccessibleResources
try:
accessible_resources = [
ConfluenceCloudOAuth.AccessibleResources(**resource)
for resource in accessible_resources_data
]
except ValidationError as e:
raise RuntimeError(f"Failed to parse accessible resources: {e}")
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred retrieving Confluence Cloud accessible resources: {str(e)}",
},
)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Confluence Cloud get accessible resources completed successfully.",
"accessible_resources": [
resource.model_dump() for resource in accessible_resources
],
}
)
@router.post("/connector/confluence/finalize")
def confluence_oauth_finalize(
credential_id: int,
cloud_id: str,
cloud_name: str,
cloud_url: str,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Saves the info for the selected cloud site to the credential.
This is the final step in the confluence oauth flow where after the traditional
OAuth process, the user has to select a site to associate with the credentials.
After this, the credential is usable."""
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if not credential:
raise HTTPException(
status_code=400,
detail=f"Confluence Cloud OAuth failed - credential {credential_id} not found.",
)
new_credential_json: dict[str, Any] = dict(credential.credential_json)
new_credential_json["cloud_id"] = cloud_id
new_credential_json["cloud_name"] = cloud_name
new_credential_json["wiki_base"] = cloud_url
try:
update_credential_json(credential_id, new_credential_json, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Confluence Cloud OAuth: {str(e)}",
},
)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Confluence Cloud OAuth finalized successfully.",
"redirect_url": f"{WEB_DOMAIN}/admin/connectors/confluence",
}
)

View File

@ -0,0 +1,229 @@
import base64
import json
import uuid
from typing import Any
from typing import cast
import requests
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from ee.onyx.server.oauth.api_router import router
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
class GoogleDriveOAuth:
# https://developers.google.com/identity/protocols/oauth2
# https://developers.google.com/identity/protocols/oauth2/web-server
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
TOKEN_URL = "https://oauth2.googleapis.com/token"
# SCOPE is per https://docs.danswer.dev/connectors/google-drive
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
SCOPE = (
"https://www.googleapis.com/auth/drive.readonly%20"
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
"https://www.googleapis.com/auth/admin.directory.group.readonly"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# without prompt=consent, a refresh token is only issued the first time the user approves
url = (
f"https://accounts.google.com/o/oauth2/v2/auth"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
"&response_type=code"
f"&scope={cls.SCOPE}"
"&access_type=offline"
f"&state={state}"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = GoogleDriveOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
return session
@router.post("/connector/google-drive/callback")
def handle_google_drive_oauth_callback(
code: str,
state: str,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Google Drive client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = GoogleDriveOAuth.parse_session(session_json)
if not DEV_MODE:
redirect_uri = GoogleDriveOAuth.REDIRECT_URI
else:
redirect_uri = GoogleDriveOAuth.DEV_REDIRECT_URI
# Exchange the authorization code for an access token
response = requests.post(
GoogleDriveOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": GoogleDriveOAuth.CLIENT_ID,
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": redirect_uri,
"grant_type": "authorization_code",
},
)
response.raise_for_status()
authorization_response: dict[str, Any] = response.json()
# the connector wants us to store the json in its authorized_user_info format
# returned from OAuthCredentials.get_authorized_user_info().
# So refresh immediately via get_google_oauth_creds with the params filled in
# from fields in authorization_response to get the json we need
authorized_user_info = {}
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
token_json_str = json.dumps(authorized_user_info)
oauth_creds = get_google_oauth_creds(
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
)
if not oauth_creds:
raise RuntimeError("get_google_oauth_creds returned None.")
# save off the credentials
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
credential_dict: dict[str, str] = {}
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
credential_dict[
DB_CREDENTIALS_AUTHENTICATION_METHOD
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
credential_info = CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
name="OAuth (interactive)",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Google Drive OAuth completed successfully.",
"finalize_url": None,
"redirect_on_success": session.redirect_on_success,
}
)

View File

@ -0,0 +1,197 @@
import base64
import uuid
from typing import cast
import requests
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from ee.onyx.server.oauth.api_router import router
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
class SlackOAuth:
# https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
# Example: https://api.slack.com/authentication/oauth-v2#exchanging
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_SLACK_CLIENT_ID
CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET
TOKEN_URL = "https://slack.com/api/oauth.v2.access"
# SCOPE is per https://docs.danswer.dev/connectors/slack
BOT_SCOPE = (
"channels:history,"
"channels:read,"
"groups:history,"
"groups:read,"
"channels:join,"
"im:history,"
"users:read,"
"users:read.email,"
"usergroups:read"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = SlackOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session
@router.post("/connector/slack/callback")
def handle_slack_oauth_callback(
code: str,
state: str,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Slack client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = SlackOAuth.parse_session(session_json)
if not DEV_MODE:
redirect_uri = SlackOAuth.REDIRECT_URI
else:
redirect_uri = SlackOAuth.DEV_REDIRECT_URI
# Exchange the authorization code for an access token
response = requests.post(
SlackOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": SlackOAuth.CLIENT_ID,
"client_secret": SlackOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": redirect_uri,
},
)
response_data = response.json()
if not response_data.get("ok"):
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed: {response_data.get('error')}",
)
# Extract token and team information
access_token: str = response_data.get("access_token")
team_id: str = response_data.get("team", {}).get("id")
authed_user_id: str = response_data.get("authed_user", {}).get("id")
credential_info = CredentialBase(
credential_json={"slack_bot_token": access_token},
admin_public=True,
source=DocumentSource.SLACK,
name="Slack OAuth",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Slack OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Slack OAuth completed successfully.",
"finalize_url": None,
"redirect_on_success": session.redirect_on_success,
"team_id": team_id,
"authed_user_id": authed_user_id,
}
)

View File

@ -423,7 +423,7 @@ def connector_external_group_sync_generator_task(
)
external_user_groups: list[ExternalUserGroup] = []
try:
external_user_groups = ext_group_sync_func(cc_pair)
external_user_groups = ext_group_sync_func(tenant_id, cc_pair)
except ConnectorValidationError as e:
msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
update_connector_credential_pair(

View File

@ -93,10 +93,11 @@ def _get_connector_runner(
runnable_connector.validate_connector_settings()
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
logger.exception("Unable to instantiate connector.")
# since we failed to even instantiate the connector, we pause the CCPair since
# it will never succeed. Sometimes there are cases where the connector will
# it will never succeed
# Sometimes there are cases where the connector will
# intermittently fail to initialize in which case we should pass in
# leave_connector_active=True to allow it to continue.
# For example, if there is nightly maintenance on a Confluence Server instance,

View File

@ -11,17 +11,20 @@ from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
from onyx.connectors.confluence.onyx_confluence import attachment_to_content
from onyx.connectors.confluence.onyx_confluence import (
extract_text_from_confluence_html,
)
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import attachment_to_content
from onyx.connectors.confluence.utils import build_confluence_document_id
from onyx.connectors.confluence.utils import datetime_from_string
from onyx.connectors.confluence.utils import extract_text_from_confluence_html
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
@ -83,7 +86,9 @@ _FULL_EXTENSION_FILTER_STRING = "".join(
)
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
class ConfluenceConnector(
LoadConnector, PollConnector, SlimConnector, CredentialsConnector
):
def __init__(
self,
wiki_base: str,
@ -102,7 +107,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
) -> None:
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
self._confluence_client: OnyxConfluence | None = None
self.is_cloud = is_cloud
# Remove trailing slash from wiki_base if present
@ -137,6 +141,19 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
self.credentials_provider: CredentialsProviderInterface | None = None
self.probe_kwargs = {
"max_backoff_retries": 6,
"max_backoff_seconds": 10,
}
self.final_kwargs = {
"max_backoff_retries": 10,
"max_backoff_seconds": 60,
}
self._confluence_client: OnyxConfluence | None = None
@property
def confluence_client(self) -> OnyxConfluence:
@ -144,15 +161,22 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
raise ConnectorMissingCredentialError("Confluence")
return self._confluence_client
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
# for a list of other hidden constructor args
self._confluence_client = build_confluence_client(
credentials=credentials,
is_cloud=self.is_cloud,
wiki_base=self.wiki_base,
def set_credentials_provider(
self, credentials_provider: CredentialsProviderInterface
) -> None:
self.credentials_provider = credentials_provider
# raises exception if there's a problem
confluence_client = OnyxConfluence(
self.is_cloud, self.wiki_base, credentials_provider
)
return None
confluence_client._probe_connection(**self.probe_kwargs)
confluence_client._initialize_connection(**self.final_kwargs)
self._confluence_client = confluence_client
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
raise NotImplementedError("Use set_credentials_provider with this connector.")
def _construct_page_query(
self,
@ -202,12 +226,17 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
return comment_string
def _convert_object_to_document(
self, confluence_object: dict[str, Any]
self,
confluence_object: dict[str, Any],
parent_content_id: str | None = None,
) -> Document | None:
"""
Takes in a confluence object, extracts all metadata, and converts it into a document.
If its a page, it extracts the text, adds the comments for the document text.
If its an attachment, it just downloads the attachment and converts that into a document.
parent_content_id: if the object is an attachment, specifies the content id that
the attachment is attached to
"""
# The url and the id are the same
object_url = build_confluence_document_id(
@ -226,7 +255,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
object_text += self._get_comment_string_for_page_id(confluence_object["id"])
elif confluence_object["type"] == "attachment":
object_text = attachment_to_content(
confluence_client=self.confluence_client, attachment=confluence_object
confluence_client=self.confluence_client,
attachment=confluence_object,
parent_content_id=parent_content_id,
)
if object_text is None:
@ -302,7 +333,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
cql=attachment_query,
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
):
doc = self._convert_object_to_document(attachment)
doc = self._convert_object_to_document(attachment, confluence_page_id)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:

View File

@ -1,19 +1,37 @@
import math
import io
import json
import time
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from typing import TypeVar
from urllib.parse import quote
import bs4
from atlassian import Confluence # type:ignore
from pydantic import BaseModel
from redis import Redis
from requests import HTTPError
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.connectors.confluence.utils import _handle_http_error
from onyx.connectors.confluence.utils import confluence_refresh_tokens
from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.html_utils import format_document_soup
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
@ -22,12 +40,14 @@ logger = setup_logger()
F = TypeVar("F", bound=Callable[..., Any])
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
# https://jira.atlassian.com/browse/CONFCLOUD-76433
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
_REPLACEMENT_EXPANSIONS = "body.view.value"
_USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
_USER_EMAIL_CACHE: dict[str, str | None] = {}
class ConfluenceRateLimitError(Exception):
pass
@ -43,124 +63,349 @@ class ConfluenceUser(BaseModel):
type: str
def _handle_http_error(e: HTTPError, attempt: int) -> int:
MIN_DELAY = 2
MAX_DELAY = 60
STARTING_DELAY = 5
BACKOFF = 2
# Check if the response or headers are None to avoid potential AttributeError
if e.response is None or e.response.headers is None:
logger.warning("HTTPError with `None` as response or as headers")
raise e
if (
e.response.status_code != 429
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
):
raise e
retry_after = None
retry_after_header = e.response.headers.get("Retry-After")
if retry_after_header is not None:
try:
retry_after = int(retry_after_header)
if retry_after > MAX_DELAY:
logger.warning(
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
)
retry_after = MAX_DELAY
if retry_after < MIN_DELAY:
retry_after = MIN_DELAY
except ValueError:
pass
if retry_after is not None:
logger.warning(
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
)
delay = retry_after
else:
logger.warning(
"Rate limiting without retry header. Retrying with exponential backoff..."
)
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
delay_until = math.ceil(time.monotonic() + delay)
return delay_until
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def handle_confluence_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
if time.monotonic() > timeout_at:
raise TimeoutError(
f"Confluence call attempts took longer than {TIMEOUT} seconds."
)
try:
# we're relying more on the client to rate limit itself
# and applying our own retries in a more specific set of circumstances
return confluence_call(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. "
f"Retrying in {delay_until} seconds..."
)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
except AttributeError as e:
# Some error within the Confluence library, unclear why it fails.
# Users reported it to be intermittent, so just retry
if attempt == MAX_RETRIES - 1:
raise e
logger.exception(
"Confluence Client raised an AttributeError. Retrying..."
)
time.sleep(5)
return cast(F, wrapped_call)
_DEFAULT_PAGINATION_LIMIT = 1000
_MINIMUM_PAGINATION_LIMIT = 50
class OnyxConfluence(Confluence):
class OnyxConfluence:
"""
This is a custom Confluence class that overrides the default Confluence class to add a custom CQL method.
This is a custom Confluence class that:
A. overrides the default Confluence class to add a custom CQL method.
B.
This is necessary because the default Confluence class does not properly support cql expansions.
All methods are automatically wrapped with handle_confluence_rate_limit.
"""
def __init__(self, url: str, *args: Any, **kwargs: Any) -> None:
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
self._wrap_methods()
CREDENTIAL_PREFIX = "connector:confluence:credential"
CREDENTIAL_TTL = 300 # 5 min
def _wrap_methods(self) -> None:
def __init__(
self,
is_cloud: bool,
url: str,
credentials_provider: CredentialsProviderInterface,
) -> None:
self._is_cloud = is_cloud
self._url = url.rstrip("/")
self._credentials_provider = credentials_provider
self.redis_client: Redis | None = None
self.static_credentials: dict[str, Any] | None = None
if self._credentials_provider.is_dynamic():
self.redis_client = get_redis_client(
tenant_id=credentials_provider.get_tenant_id()
)
else:
self.static_credentials = self._credentials_provider.get_credentials()
self._confluence = Confluence(url)
self.credential_key: str = (
self.CREDENTIAL_PREFIX
+ f":credential_{self._credentials_provider.get_provider_key()}"
)
self._kwargs: Any = None
self.shared_base_kwargs = {
"api_version": "cloud" if is_cloud else "latest",
"backoff_and_retry": True,
"cloud": is_cloud,
}
def _renew_credentials(self) -> tuple[dict[str, Any], bool]:
"""credential_json - the current json credentials
Returns a tuple
1. The up to date credentials
2. True if the credentials were updated
This method is intended to be used within a distributed lock.
Lock, call this, update credentials if the tokens were refreshed, then release
"""
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
wrap it with handle_confluence_rate_limit.
"""
for attr_name in dir(self):
if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
setattr(
self,
attr_name,
handle_confluence_rate_limit(getattr(self, attr_name)),
# static credentials are preloaded, so no locking/redis required
if self.static_credentials:
return self.static_credentials, False
if not self.redis_client:
raise RuntimeError("self.redis_client is None")
# dynamic credentials need locking
# check redis first, then fallback to the DB
credential_raw = self.redis_client.get(self.credential_key)
if credential_raw is not None:
credential_bytes = cast(bytes, credential_raw)
credential_str = credential_bytes.decode("utf-8")
credential_json: dict[str, Any] = json.loads(credential_str)
else:
credential_json = self._credentials_provider.get_credentials()
if "confluence_refresh_token" not in credential_json:
# static credentials ... cache them permanently and return
self.static_credentials = credential_json
return credential_json, False
# check if we should refresh tokens. we're deciding to refresh halfway
# to expiration
now = datetime.now(timezone.utc)
created_at = datetime.fromisoformat(credential_json["created_at"])
expires_in: int = credential_json["expires_in"]
renew_at = created_at + timedelta(seconds=expires_in // 2)
if now <= renew_at:
# cached/current credentials are reasonably up to date
return credential_json, False
# we need to refresh
logger.info("Renewing Confluence Cloud credentials...")
new_credentials = confluence_refresh_tokens(
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID,
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET,
credential_json["cloud_id"],
credential_json["confluence_refresh_token"],
)
# store the new credentials to redis and to the db thru the provider
# redis: we use a 5 min TTL because we are given a 10 minute grace period
# when keys are rotated. it's easier to expire the cached credentials
# reasonably frequently rather than trying to handle strong synchronization
# between the db and redis everywhere the credentials might be updated
new_credential_str = json.dumps(new_credentials)
self.redis_client.set(
self.credential_key, new_credential_str, nx=True, ex=self.CREDENTIAL_TTL
)
self._credentials_provider.set_credentials(new_credentials)
return new_credentials, True
@staticmethod
def _make_oauth2_dict(credentials: dict[str, Any]) -> dict[str, Any]:
oauth2_dict: dict[str, Any] = {}
if "confluence_refresh_token" in credentials:
oauth2_dict["client_id"] = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
oauth2_dict["token"] = {}
oauth2_dict["token"]["access_token"] = credentials[
"confluence_access_token"
]
return oauth2_dict
def _probe_connection(
self,
**kwargs: Any,
) -> None:
merged_kwargs = {**self.shared_base_kwargs, **kwargs}
with self._credentials_provider:
credentials, _ = self._renew_credentials()
# probe connection with direct client, no retries
if "confluence_refresh_token" in credentials:
logger.info("Probing Confluence with OAuth Access Token.")
oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(
credentials
)
url = (
f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
)
confluence_client_with_minimal_retries = Confluence(
url=url, oauth2=oauth2_dict, **merged_kwargs
)
else:
logger.info("Probing Confluence with Personal Access Token.")
url = self._url
if self._is_cloud:
confluence_client_with_minimal_retries = Confluence(
url=url,
username=credentials["confluence_username"],
password=credentials["confluence_access_token"],
**merged_kwargs,
)
else:
confluence_client_with_minimal_retries = Confluence(
url=url,
token=credentials["confluence_access_token"],
**merged_kwargs,
)
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
# uncomment the following for testing
# the following is an attempt to retrieve the user's timezone
# Unfornately, all data is returned in UTC regardless of the user's time zone
# even tho CQL parses incoming times based on the user's time zone
# space_key = spaces["results"][0]["key"]
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
if not spaces:
raise RuntimeError(
f"No spaces found at {url}! "
"Check your credentials and wiki_base and make sure "
"is_cloud is set correctly."
)
logger.info("Confluence probe succeeded.")
def _initialize_connection(
self,
**kwargs: Any,
) -> None:
"""Called externally to init the connection in a thread safe manner."""
merged_kwargs = {**self.shared_base_kwargs, **kwargs}
with self._credentials_provider:
credentials, _ = self._renew_credentials()
self._confluence = self._initialize_connection_helper(
credentials, **merged_kwargs
)
self._kwargs = merged_kwargs
def _initialize_connection_helper(
self,
credentials: dict[str, Any],
**kwargs: Any,
) -> Confluence:
"""Called internally to init the connection. Distributed locking
to prevent multiple threads from modifying the credentials
must be handled around this function."""
confluence = None
# probe connection with direct client, no retries
if "confluence_refresh_token" in credentials:
logger.info("Connecting to Confluence Cloud with OAuth Access Token.")
oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(credentials)
url = f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
confluence = Confluence(url=url, oauth2=oauth2_dict, **kwargs)
else:
logger.info("Connecting to Confluence with Personal Access Token.")
if self._is_cloud:
confluence = Confluence(
url=self._url,
username=credentials["confluence_username"],
password=credentials["confluence_access_token"],
**kwargs,
)
else:
confluence = Confluence(
url=self._url,
token=credentials["confluence_access_token"],
**kwargs,
)
return confluence
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def _make_rate_limited_confluence_method(
self, name: str, credential_provider: CredentialsProviderInterface | None
) -> Callable[..., Any]:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
if time.monotonic() > timeout_at:
raise TimeoutError(
f"Confluence call attempts took longer than {TIMEOUT} seconds."
)
# we're relying more on the client to rate limit itself
# and applying our own retries in a more specific set of circumstances
try:
if credential_provider:
with credential_provider:
credentials, renewed = self._renew_credentials()
if renewed:
self._confluence = self._initialize_connection_helper(
credentials, **self._kwargs
)
attr = getattr(self._confluence, name, None)
if attr is None:
# The underlying Confluence client doesn't have this attribute
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
return attr(*args, **kwargs)
else:
attr = getattr(self._confluence, name, None)
if attr is None:
# The underlying Confluence client doesn't have this attribute
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
return attr(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. "
f"Retrying in {delay_until} seconds..."
)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
except AttributeError as e:
# Some error within the Confluence library, unclear why it fails.
# Users reported it to be intermittent, so just retry
if attempt == MAX_RETRIES - 1:
raise e
logger.exception(
"Confluence Client raised an AttributeError. Retrying..."
)
time.sleep(5)
return wrapped_call
# def _wrap_methods(self) -> None:
# """
# For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
# wrap it with handle_confluence_rate_limit.
# """
# for attr_name in dir(self):
# if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
# setattr(
# self,
# attr_name,
# handle_confluence_rate_limit(getattr(self, attr_name)),
# )
# def _ensure_token_valid(self) -> None:
# if self._token_is_expired():
# self._refresh_token()
# # Re-init the Confluence client with the originally stored args
# self._confluence = Confluence(self._url, *self._args, **self._kwargs)
def __getattr__(self, name: str) -> Any:
"""Dynamically intercept attribute/method access."""
attr = getattr(self._confluence, name, None)
if attr is None:
# The underlying Confluence client doesn't have this attribute
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
# If it's not a method, just return it after ensuring token validity
if not callable(attr):
return attr
# skip methods that start with "_"
if name.startswith("_"):
return attr
# wrap the method with our retry handler
rate_limited_method: Callable[
..., Any
] = self._make_rate_limited_confluence_method(name, self._credentials_provider)
def wrapped_method(*args: Any, **kwargs: Any) -> Any:
return rate_limited_method(*args, **kwargs)
return wrapped_method
def _paginate_url(
self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False
@ -507,63 +752,212 @@ class OnyxConfluence(Confluence):
return response
def _validate_connector_configuration(
credentials: dict[str, Any],
is_cloud: bool,
wiki_base: str,
) -> None:
# test connection with direct client, no retries
confluence_client_with_minimal_retries = Confluence(
api_version="cloud" if is_cloud else "latest",
url=wiki_base.rstrip("/"),
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=6,
max_backoff_seconds=10,
def get_user_email_from_username__server(
confluence_client: OnyxConfluence, user_name: str
) -> str | None:
global _USER_EMAIL_CACHE
if _USER_EMAIL_CACHE.get(user_name) is None:
try:
response = confluence_client.get_mobile_parameters(user_name)
email = response.get("email")
except Exception:
logger.warning(f"failed to get confluence email for {user_name}")
# For now, we'll just return None and log a warning. This means
# we will keep retrying to get the email every group sync.
email = None
# We may want to just return a string that indicates failure so we dont
# keep retrying
# email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
_USER_EMAIL_CACHE[user_name] = email
return _USER_EMAIL_CACHE[user_name]
def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
Args:
user_id (str): The user id (i.e: the account-id or userkey)
confluence_client (Confluence): The Confluence Client
Returns:
str: The User Display Name. 'Unknown User' if the user is deactivated or not found
"""
global _USER_ID_TO_DISPLAY_NAME_CACHE
if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
try:
result = confluence_client.get_user_details_by_userkey(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
if not found_display_name:
try:
result = confluence_client.get_user_details_by_accountid(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
_USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
def attachment_to_content(
confluence_client: OnyxConfluence,
attachment: dict[str, Any],
parent_content_id: str | None = None,
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if not validate_attachment_filetype(attachment):
return None
if "api.atlassian.com" in confluence_client.url:
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
if not parent_content_id:
logger.warning(
"parent_content_id is required to download attachments from Confluence Cloud!"
)
return None
download_link = (
confluence_client.url
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
)
else:
download_link = confluence_client.url + attachment["_links"]["download"]
attachment_size = attachment["extensions"]["fileSize"]
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
return None
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
# why are we using session.get here? we probably won't retry these ... is that ok?
response = confluence_client._session.get(download_link)
if response.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
)
return None
extracted_text = extract_file_text(
io.BytesIO(response.content),
file_name=attachment["title"],
break_on_unprocessable=False,
)
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to char count. "
f"char count={len(extracted_text)} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
)
return None
# uncomment the following for testing
# the following is an attempt to retrieve the user's timezone
# Unfornately, all data is returned in UTC regardless of the user's time zone
# even tho CQL parses incoming times based on the user's time zone
# space_key = spaces["results"][0]["key"]
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
return extracted_text
if not spaces:
raise RuntimeError(
f"No spaces found at {wiki_base}! "
"Check your credentials and wiki_base and make sure "
"is_cloud is set correctly."
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence,
confluence_object: dict[str, Any],
fetched_titles: set[str],
) -> str:
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
Args:
confluence_object (dict): The confluence object as a dict
confluence_client (Confluence): Confluence client
fetched_titles (set[str]): The titles of the pages that have already been fetched
Returns:
str: loaded and formated Confluence page
"""
body = confluence_object["body"]
object_html = body.get("storage", body.get("view", {})).get("value")
soup = bs4.BeautifulSoup(object_html, "html.parser")
for user in soup.findAll("ri:user"):
user_id = (
user.attrs["ri:account-id"]
if "ri:account-id" in user.attrs
else user.get("ri:userkey")
)
if not user_id:
logger.warning(
"ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}"
)
continue
# Include @ sign for tagging, more clear for LLM
user.replaceWith("@" + _get_user(confluence_client, user_id))
for html_page_reference in soup.findAll("ac:structured-macro"):
# Here, we only want to process page within page macros
if html_page_reference.attrs.get("ac:name") != "include":
continue
page_data = html_page_reference.find("ri:page")
if not page_data:
logger.warning(
f"Skipping retrieval of {html_page_reference} because because page data is missing"
)
continue
page_title = page_data.attrs.get("ri:content-title")
if not page_title:
# only fetch pages that have a title
logger.warning(
f"Skipping retrieval of {html_page_reference} because it has no title"
)
continue
if page_title in fetched_titles:
# prevent recursive fetching of pages
logger.debug(f"Skipping {page_title} because it has already been fetched")
continue
fetched_titles.add(page_title)
# Wrap this in a try-except because there are some pages that might not exist
try:
page_query = f"type=page and title='{quote(page_title)}'"
page_contents: dict[str, Any] | None = None
# Confluence enforces title uniqueness, so we should only get one result here
for page in confluence_client.paginated_cql_retrieval(
cql=page_query,
expand="body.storage.value",
limit=1,
):
page_contents = page
break
except Exception as e:
logger.warning(
f"Error getting page contents for object {confluence_object}: {e}"
)
continue
if not page_contents:
continue
text_from_page = extract_text_from_confluence_html(
confluence_client=confluence_client,
confluence_object=page_contents,
fetched_titles=fetched_titles,
)
html_page_reference.replaceWith(text_from_page)
def build_confluence_client(
credentials: dict[str, Any],
is_cloud: bool,
wiki_base: str,
) -> OnyxConfluence:
try:
_validate_connector_configuration(
credentials=credentials,
is_cloud=is_cloud,
wiki_base=wiki_base,
)
except Exception as e:
raise ConnectorValidationError(str(e))
for html_link_body in soup.findAll("ac:link-body"):
# This extracts the text from inline links in the page so they can be
# represented in the document text as plain text
try:
text_from_link = html_link_body.text
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
except Exception as e:
logger.warning(f"Error processing ac:link-body: {e}")
return OnyxConfluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present
url=wiki_base.rstrip("/"),
# passing in username causes issues for Confluence data center
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=10,
max_backoff_seconds=60,
cloud=is_cloud,
)
return format_document_soup(soup)

View File

@ -1,185 +1,38 @@
import io
import math
import time
from collections.abc import Callable
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
from typing import TypeVar
from urllib.parse import parse_qs
from urllib.parse import quote
from urllib.parse import urlparse
import bs4
import requests
from pydantic import BaseModel
from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.html_utils import format_document_soup
from onyx.utils.logger import setup_logger
if TYPE_CHECKING:
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
pass
logger = setup_logger()
_USER_EMAIL_CACHE: dict[str, str | None] = {}
CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token"
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
def get_user_email_from_username__server(
confluence_client: "OnyxConfluence", user_name: str
) -> str | None:
global _USER_EMAIL_CACHE
if _USER_EMAIL_CACHE.get(user_name) is None:
try:
response = confluence_client.get_mobile_parameters(user_name)
email = response.get("email")
except Exception:
logger.warning(f"failed to get confluence email for {user_name}")
# For now, we'll just return None and log a warning. This means
# we will keep retrying to get the email every group sync.
email = None
# We may want to just return a string that indicates failure so we dont
# keep retrying
# email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
_USER_EMAIL_CACHE[user_name] = email
return _USER_EMAIL_CACHE[user_name]
_USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
Args:
user_id (str): The user id (i.e: the account-id or userkey)
confluence_client (Confluence): The Confluence Client
Returns:
str: The User Display Name. 'Unknown User' if the user is deactivated or not found
"""
global _USER_ID_TO_DISPLAY_NAME_CACHE
if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
try:
result = confluence_client.get_user_details_by_userkey(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
if not found_display_name:
try:
result = confluence_client.get_user_details_by_accountid(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
_USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
def extract_text_from_confluence_html(
confluence_client: "OnyxConfluence",
confluence_object: dict[str, Any],
fetched_titles: set[str],
) -> str:
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
Args:
confluence_object (dict): The confluence object as a dict
confluence_client (Confluence): Confluence client
fetched_titles (set[str]): The titles of the pages that have already been fetched
Returns:
str: loaded and formated Confluence page
"""
body = confluence_object["body"]
object_html = body.get("storage", body.get("view", {})).get("value")
soup = bs4.BeautifulSoup(object_html, "html.parser")
for user in soup.findAll("ri:user"):
user_id = (
user.attrs["ri:account-id"]
if "ri:account-id" in user.attrs
else user.get("ri:userkey")
)
if not user_id:
logger.warning(
"ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}"
)
continue
# Include @ sign for tagging, more clear for LLM
user.replaceWith("@" + _get_user(confluence_client, user_id))
for html_page_reference in soup.findAll("ac:structured-macro"):
# Here, we only want to process page within page macros
if html_page_reference.attrs.get("ac:name") != "include":
continue
page_data = html_page_reference.find("ri:page")
if not page_data:
logger.warning(
f"Skipping retrieval of {html_page_reference} because because page data is missing"
)
continue
page_title = page_data.attrs.get("ri:content-title")
if not page_title:
# only fetch pages that have a title
logger.warning(
f"Skipping retrieval of {html_page_reference} because it has no title"
)
continue
if page_title in fetched_titles:
# prevent recursive fetching of pages
logger.debug(f"Skipping {page_title} because it has already been fetched")
continue
fetched_titles.add(page_title)
# Wrap this in a try-except because there are some pages that might not exist
try:
page_query = f"type=page and title='{quote(page_title)}'"
page_contents: dict[str, Any] | None = None
# Confluence enforces title uniqueness, so we should only get one result here
for page in confluence_client.paginated_cql_retrieval(
cql=page_query,
expand="body.storage.value",
limit=1,
):
page_contents = page
break
except Exception as e:
logger.warning(
f"Error getting page contents for object {confluence_object}: {e}"
)
continue
if not page_contents:
continue
text_from_page = extract_text_from_confluence_html(
confluence_client=confluence_client,
confluence_object=page_contents,
fetched_titles=fetched_titles,
)
html_page_reference.replaceWith(text_from_page)
for html_link_body in soup.findAll("ac:link-body"):
# This extracts the text from inline links in the page so they can be
# represented in the document text as plain text
try:
text_from_link = html_link_body.text
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
except Exception as e:
logger.warning(f"Error processing ac:link-body: {e}")
return format_document_soup(soup)
class TokenResponse(BaseModel):
access_token: str
expires_in: int
token_type: str
refresh_token: str
scope: str
def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
@ -193,49 +46,6 @@ def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
]
def attachment_to_content(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if not validate_attachment_filetype(attachment):
return None
download_link = confluence_client.url + attachment["_links"]["download"]
attachment_size = attachment["extensions"]["fileSize"]
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
return None
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
response = confluence_client._session.get(download_link)
if response.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
)
return None
extracted_text = extract_file_text(
io.BytesIO(response.content),
file_name=attachment["title"],
break_on_unprocessable=False,
)
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to char count. "
f"char count={len(extracted_text)} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
)
return None
return extracted_text
def build_confluence_document_id(
base_url: str, content_url: str, is_cloud: bool
) -> str:
@ -284,6 +94,137 @@ def datetime_from_string(datetime_string: str) -> datetime:
return datetime_object
def confluence_refresh_tokens(
client_id: str, client_secret: str, cloud_id: str, refresh_token: str
) -> dict[str, Any]:
# rotate the refresh and access token
# Note that access tokens are only good for an hour in confluence cloud,
# so we're going to have problems if the connector runs for longer
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/#use-a-refresh-token-to-get-another-access-token-and-refresh-token-pair
response = requests.post(
CONFLUENCE_OAUTH_TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"grant_type": "refresh_token",
"client_id": client_id,
"client_secret": client_secret,
"refresh_token": refresh_token,
},
)
try:
token_response = TokenResponse.model_validate_json(response.text)
except Exception:
raise RuntimeError("Confluence Cloud token refresh failed.")
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=token_response.expires_in)
new_credentials: dict[str, Any] = {}
new_credentials["confluence_access_token"] = token_response.access_token
new_credentials["confluence_refresh_token"] = token_response.refresh_token
new_credentials["created_at"] = now.isoformat()
new_credentials["expires_at"] = expires_at.isoformat()
new_credentials["expires_in"] = token_response.expires_in
new_credentials["scope"] = token_response.scope
new_credentials["cloud_id"] = cloud_id
return new_credentials
F = TypeVar("F", bound=Callable[..., Any])
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def handle_confluence_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
if time.monotonic() > timeout_at:
raise TimeoutError(
f"Confluence call attempts took longer than {TIMEOUT} seconds."
)
try:
# we're relying more on the client to rate limit itself
# and applying our own retries in a more specific set of circumstances
return confluence_call(*args, **kwargs)
except requests.HTTPError as e:
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. "
f"Retrying in {delay_until} seconds..."
)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
except AttributeError as e:
# Some error within the Confluence library, unclear why it fails.
# Users reported it to be intermittent, so just retry
if attempt == MAX_RETRIES - 1:
raise e
logger.exception(
"Confluence Client raised an AttributeError. Retrying..."
)
time.sleep(5)
return cast(F, wrapped_call)
def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
MIN_DELAY = 2
MAX_DELAY = 60
STARTING_DELAY = 5
BACKOFF = 2
# Check if the response or headers are None to avoid potential AttributeError
if e.response is None or e.response.headers is None:
logger.warning("HTTPError with `None` as response or as headers")
raise e
if (
e.response.status_code != 429
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
):
raise e
retry_after = None
retry_after_header = e.response.headers.get("Retry-After")
if retry_after_header is not None:
try:
retry_after = int(retry_after_header)
if retry_after > MAX_DELAY:
logger.warning(
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
)
retry_after = MAX_DELAY
if retry_after < MIN_DELAY:
retry_after = MIN_DELAY
except ValueError:
pass
if retry_after is not None:
logger.warning(
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
)
delay = retry_after
else:
logger.warning(
"Rate limiting without retry header. Retrying with exponential backoff..."
)
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
delay_until = math.ceil(time.monotonic() + delay)
return delay_until
def get_single_param_from_url(url: str, param: str) -> str | None:
"""Get a parameter from a url"""
parsed_url = urlparse(url)

View File

@ -0,0 +1,135 @@
import uuid
from types import TracebackType
from typing import Any
from redis.lock import Lock as RedisLock
from sqlalchemy import select
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import Credential
from onyx.redis.redis_pool import get_redis_client
class OnyxDBCredentialsProvider(
CredentialsProviderInterface["OnyxDBCredentialsProvider"]
):
"""Implementation to allow the connector to callback and update credentials in the db.
Required in cases where credentials can rotate while the connector is running.
"""
LOCK_TTL = 900 # TTL of the lock
def __init__(self, tenant_id: str, connector_name: str, credential_id: int):
self._tenant_id = tenant_id
self._connector_name = connector_name
self._credential_id = credential_id
self.redis_client = get_redis_client(tenant_id=tenant_id)
# lock used to prevent overlapping renewal of credentials
self.lock_key = f"da_lock:connector:{connector_name}:credential_{credential_id}"
self._lock: RedisLock = self.redis_client.lock(self.lock_key, self.LOCK_TTL)
def __enter__(self) -> "OnyxDBCredentialsProvider":
acquired = self._lock.acquire(blocking_timeout=self.LOCK_TTL)
if not acquired:
raise RuntimeError(f"Could not acquire lock for key: {self.lock_key}")
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
"""Release the lock when exiting the context."""
if self._lock and self._lock.owned():
self._lock.release()
def get_tenant_id(self) -> str | None:
return self._tenant_id
def get_provider_key(self) -> str:
return str(self._credential_id)
def get_credentials(self) -> dict[str, Any]:
with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
credential = db_session.execute(
select(Credential).where(Credential.id == self._credential_id)
).scalar_one()
if credential is None:
raise ValueError(
f"No credential found: credential={self._credential_id}"
)
return credential.credential_json
def set_credentials(self, credential_json: dict[str, Any]) -> None:
with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
try:
credential = db_session.execute(
select(Credential)
.where(Credential.id == self._credential_id)
.with_for_update()
).scalar_one()
if credential is None:
raise ValueError(
f"No credential found: credential={self._credential_id}"
)
credential.credential_json = credential_json
db_session.commit()
except Exception:
db_session.rollback()
raise
def is_dynamic(self) -> bool:
return True
class OnyxStaticCredentialsProvider(
CredentialsProviderInterface["OnyxStaticCredentialsProvider"]
):
"""Implementation (a very simple one!) to handle static credentials."""
def __init__(
self,
tenant_id: str | None,
connector_name: str,
credential_json: dict[str, Any],
):
self._tenant_id = tenant_id
self._connector_name = connector_name
self._credential_json = credential_json
self._provider_key = str(uuid.uuid4())
def __enter__(self) -> "OnyxStaticCredentialsProvider":
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
pass
def get_tenant_id(self) -> str | None:
return self._tenant_id
def get_provider_key(self) -> str:
return self._provider_key
def get_credentials(self) -> dict[str, Any]:
return self._credential_json
def set_credentials(self, credential_json: dict[str, Any]) -> None:
self._credential_json = credential_json
def is_dynamic(self) -> bool:
return False

View File

@ -12,6 +12,7 @@ from onyx.connectors.blob.connector import BlobStorageConnector
from onyx.connectors.bookstack.connector import BookstackConnector
from onyx.connectors.clickup.connector import ClickupConnector
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.discord.connector import DiscordConnector
from onyx.connectors.discourse.connector import DiscourseConnector
from onyx.connectors.document360.connector import Document360Connector
@ -32,6 +33,7 @@ from onyx.connectors.guru.connector import GuruConnector
from onyx.connectors.hubspot.connector import HubSpotConnector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import EventConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@ -57,6 +59,7 @@ from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import backend_update_credential_json
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.models import Credential
from shared_configs.contextvars import get_current_tenant_id
class ConnectorMissingException(Exception):
@ -167,10 +170,17 @@ def instantiate_connector(
connector_class = identify_connector_class(source, input_type)
connector = connector_class(**connector_specific_config)
new_credentials = connector.load_credentials(credential.credential_json)
if new_credentials is not None:
backend_update_credential_json(credential, new_credentials, db_session)
if isinstance(connector, CredentialsConnector):
provider = OnyxDBCredentialsProvider(
get_current_tenant_id(), str(source), credential.id
)
connector.set_credentials_provider(provider)
else:
new_credentials = connector.load_credentials(credential.credential_json)
if new_credentials is not None:
backend_update_credential_json(credential, new_credentials, db_session)
return connector

View File

@ -1,7 +1,10 @@
import abc
from collections.abc import Generator
from collections.abc import Iterator
from types import TracebackType
from typing import Any
from typing import Generic
from typing import TypeVar
from pydantic import BaseModel
@ -111,6 +114,69 @@ class OAuthConnector(BaseConnector):
raise NotImplementedError
T = TypeVar("T", bound="CredentialsProviderInterface")
class CredentialsProviderInterface(abc.ABC, Generic[T]):
@abc.abstractmethod
def __enter__(self) -> T:
raise NotImplementedError
@abc.abstractmethod
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
raise NotImplementedError
@abc.abstractmethod
def get_tenant_id(self) -> str | None:
raise NotImplementedError
@abc.abstractmethod
def get_provider_key(self) -> str:
"""a unique key that the connector can use to lock around a credential
that might be used simultaneously.
Will typically be the credential id, but can also just be something random
in cases when there is nothing to lock (aka static credentials)
"""
raise NotImplementedError
@abc.abstractmethod
def get_credentials(self) -> dict[str, Any]:
raise NotImplementedError
@abc.abstractmethod
def set_credentials(self, credential_json: dict[str, Any]) -> None:
raise NotImplementedError
@abc.abstractmethod
def is_dynamic(self) -> bool:
"""If dynamic, the credentials may change during usage ... maening the client
needs to use the locking features of the credentials provider to operate
correctly.
If static, the client can simply reference the credentials once and use them
through the entire indexing run.
"""
raise NotImplementedError
class CredentialsConnector(BaseConnector):
"""Implement this if the connector needs to be able to read and write credentials
on the fly. Typically used with shared credentials/tokens that might be renewed
at any time."""
@abc.abstractmethod
def set_credentials_provider(
self, credentials_provider: CredentialsProviderInterface
) -> None:
raise NotImplementedError
# Event driven
class EventConnector(BaseConnector):
@abc.abstractmethod

View File

@ -302,29 +302,29 @@ class WebConnector(LoadConnector):
playwright, context = start_playwright()
restart_playwright = False
while to_visit:
current_url = to_visit.pop()
if current_url in visited_links:
initial_url = to_visit.pop()
if initial_url in visited_links:
continue
visited_links.add(current_url)
visited_links.add(initial_url)
try:
protected_url_check(current_url)
protected_url_check(initial_url)
except Exception as e:
last_error = f"Invalid URL {current_url} due to {e}"
last_error = f"Invalid URL {initial_url} due to {e}"
logger.warning(last_error)
continue
logger.info(f"Visiting {current_url}")
logger.info(f"{len(visited_links)}: Visiting {initial_url}")
try:
check_internet_connection(current_url)
check_internet_connection(initial_url)
if restart_playwright:
playwright, context = start_playwright()
restart_playwright = False
if current_url.split(".")[-1] == "pdf":
if initial_url.split(".")[-1] == "pdf":
# PDF files are not checked for links
response = requests.get(current_url)
response = requests.get(initial_url)
page_text, metadata = read_pdf_file(
file=io.BytesIO(response.content)
)
@ -332,10 +332,10 @@ class WebConnector(LoadConnector):
doc_batch.append(
Document(
id=current_url,
sections=[Section(link=current_url, text=page_text)],
id=initial_url,
sections=[Section(link=initial_url, text=page_text)],
source=DocumentSource.WEB,
semantic_identifier=current_url.split("/")[-1],
semantic_identifier=initial_url.split("/")[-1],
metadata=metadata,
doc_updated_at=_get_datetime_from_last_modified_header(
last_modified
@ -347,21 +347,25 @@ class WebConnector(LoadConnector):
continue
page = context.new_page()
page_response = page.goto(current_url)
page_response = page.goto(initial_url)
last_modified = (
page_response.header_value("Last-Modified")
if page_response
else None
)
final_page = page.url
if final_page != current_url:
logger.info(f"Redirected to {final_page}")
protected_url_check(final_page)
current_url = final_page
if current_url in visited_links:
logger.info("Redirected page already indexed")
final_url = page.url
if final_url != initial_url:
protected_url_check(final_url)
initial_url = final_url
if initial_url in visited_links:
logger.info(
f"{len(visited_links)}: {initial_url} redirected to {final_url} - already indexed"
)
continue
visited_links.add(current_url)
logger.info(
f"{len(visited_links)}: {initial_url} redirected to {final_url}"
)
visited_links.add(initial_url)
if self.scroll_before_scraping:
scroll_attempts = 0
@ -379,13 +383,13 @@ class WebConnector(LoadConnector):
soup = BeautifulSoup(content, "html.parser")
if self.recursive:
internal_links = get_internal_links(base_url, current_url, soup)
internal_links = get_internal_links(base_url, initial_url, soup)
for link in internal_links:
if link not in visited_links:
to_visit.append(link)
if page_response and str(page_response.status)[0] in ("4", "5"):
last_error = f"Skipped indexing {current_url} due to HTTP {page_response.status} response"
last_error = f"Skipped indexing {initial_url} due to HTTP {page_response.status} response"
logger.info(last_error)
continue
@ -393,12 +397,12 @@ class WebConnector(LoadConnector):
doc_batch.append(
Document(
id=current_url,
id=initial_url,
sections=[
Section(link=current_url, text=parsed_html.cleaned_text)
Section(link=initial_url, text=parsed_html.cleaned_text)
],
source=DocumentSource.WEB,
semantic_identifier=parsed_html.title or current_url,
semantic_identifier=parsed_html.title or initial_url,
metadata={},
doc_updated_at=_get_datetime_from_last_modified_header(
last_modified
@ -410,7 +414,7 @@ class WebConnector(LoadConnector):
page.close()
except Exception as e:
last_error = f"Failed to fetch '{current_url}': {e}"
last_error = f"Failed to fetch '{initial_url}': {e}"
logger.exception(last_error)
playwright.stop()
restart_playwright = True

View File

@ -51,7 +51,6 @@ from onyx.server.documents.cc_pair import router as cc_pair_router
from onyx.server.documents.connector import router as connector_router
from onyx.server.documents.credential import router as credential_router
from onyx.server.documents.document import router as document_router
from onyx.server.documents.standard_oauth import router as oauth_router
from onyx.server.features.document_set.api import router as document_set_router
from onyx.server.features.folder.api import router as folder_router
from onyx.server.features.input_prompt.api import (
@ -323,7 +322,6 @@ def get_application() -> FastAPI:
)
include_router_with_global_prefix_prepended(application, long_term_logs_router)
include_router_with_global_prefix_prepended(application, api_key_router)
include_router_with_global_prefix_prepended(application, oauth_router)
if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step

View File

@ -46,13 +46,21 @@ def mask_string(sensitive_str: str) -> str:
return "****...**" + sensitive_str[-4:]
MASK_CREDENTIALS_WHITELIST = {
DB_CREDENTIALS_AUTHENTICATION_METHOD,
"wiki_base",
"cloud_name",
"cloud_id",
}
def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
masked_creds = {}
for key, val in credential_dict.items():
if isinstance(val, str):
# we want to pass the authentication_method field through so the frontend
# can disambiguate credentials created by different methods
if key == DB_CREDENTIALS_AUTHENTICATION_METHOD:
if key in MASK_CREDENTIALS_WHITELIST:
masked_creds[key] = val
else:
masked_creds[key] = mask_string(val)
@ -63,8 +71,8 @@ def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
continue
raise ValueError(
f"Unable to mask credentials of type other than string, cannot process request."
f"Recieved type: {type(val)}"
f"Unable to mask credentials of type other than string or int, cannot process request."
f"Received type: {type(val)}"
)
return masked_creds

View File

@ -5,7 +5,9 @@ from unittest.mock import patch
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
from onyx.connectors.models import Document
@ -18,12 +20,15 @@ def confluence_connector() -> ConfluenceConnector:
page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""),
)
connector.load_credentials(
credentials_provider = OnyxStaticCredentialsProvider(
None,
DocumentSource.CONFLUENCE,
{
"confluence_username": os.environ["CONFLUENCE_USER_NAME"],
"confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
}
},
)
connector.set_credentials_provider(credentials_provider)
return connector

View File

@ -2,7 +2,9 @@ import os
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
@pytest.fixture
@ -11,12 +13,16 @@ def confluence_connector() -> ConfluenceConnector:
wiki_base="https://danswerai.atlassian.net",
is_cloud=True,
)
connector.load_credentials(
credentials_provider = OnyxStaticCredentialsProvider(
None,
DocumentSource.CONFLUENCE,
{
"confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
"confluence_username": os.environ["CONFLUENCE_USER_NAME"],
}
"confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
},
)
connector.set_credentials_provider(credentials_provider)
return connector

View File

@ -3,9 +3,7 @@ from unittest.mock import Mock
import pytest
from requests import HTTPError
from onyx.connectors.confluence.onyx_confluence import (
handle_confluence_rate_limit,
)
from onyx.connectors.confluence.utils import handle_confluence_rate_limit
@pytest.fixture
@ -50,6 +48,8 @@ def mock_confluence_call() -> Mock:
# mock_sleep.assert_called_with(int(retry_after))
# NOTE(rkuo): This tests an older version of rate limiting that is being deprecated
# and probably should go away soon.
def test_non_rate_limit_error(mock_confluence_call: Mock) -> None:
mock_confluence_call.side_effect = HTTPError(
response=Mock(status_code=500, text="Internal Server Error")

View File

@ -100,11 +100,6 @@ export default function LabelManagement() {
width="w-full max-w-xs"
name={`editLabelName_${label.id}`}
label="Label Name"
value={
values.editLabelId === label.id
? values.editLabelName
: label.name
}
onChange={(e) => {
setFieldValue("editLabelId", label.id);
setFieldValue("editLabelName", e.target.value);

View File

@ -52,7 +52,6 @@ export default function StarterMessagesList({
<TextFormField
name={`starter_messages.${index}.message`}
label=""
value={starterMessage.message}
onChange={(e) => handleInputChange(index, e.target.value)}
className="flex-grow"
removeLabel

View File

@ -193,13 +193,15 @@ export default function AddConnector({
// Check if there are no credentials
const noCredentials = credentialTemplate == null;
if (noCredentials && 1 != formStep) {
setFormStep(Math.max(1, formStep));
}
useEffect(() => {
if (noCredentials && 1 != formStep) {
setFormStep(Math.max(1, formStep));
}
if (!noCredentials && !credentialActivated && formStep != 0) {
setFormStep(Math.min(formStep, 0));
}
if (!noCredentials && !credentialActivated && formStep != 0) {
setFormStep(Math.min(formStep, 0));
}
}, [noCredentials, formStep, setFormStep]);
const convertStringToDateTime = (indexingStart: string | null) => {
return indexingStart ? new Date(indexingStart) : null;

View File

@ -33,7 +33,7 @@ export default function OAuthCallbackPage() {
const connector = pathname?.split("/")[3];
useEffect(() => {
const handleOAuthCallback = async () => {
const onFirstLoad = async () => {
// Examples
// connector (url segment)= "google-drive"
// sourceType (for looking up metadata) = "google_drive"
@ -85,10 +85,19 @@ export default function OAuthCallbackPage() {
}
setStatusMessage("Success!");
setStatusDetails(
`Your authorization with ${sourceMetadata.displayName} completed successfully.`
);
setRedirectUrl(response.redirect_on_success); // Extract the redirect URL
// set the continuation link
if (response.finalize_url) {
setRedirectUrl(response.finalize_url);
setStatusDetails(
`Your authorization with ${sourceMetadata.displayName} completed successfully. Additional steps are required to complete credential setup.`
);
} else {
setRedirectUrl(response.redirect_on_success);
setStatusDetails(
`Your authorization with ${sourceMetadata.displayName} completed successfully.`
);
}
setIsError(false);
} catch (error) {
console.error("OAuth error:", error);
@ -100,15 +109,15 @@ export default function OAuthCallbackPage() {
}
};
handleOAuthCallback();
onFirstLoad();
}, [code, state, connector]);
return (
<div className="container mx-auto py-8">
<div className="mx-auto h-screen flex flex-col">
<AdminPageTitle title={pageTitle} icon={<KeyIcon size={32} />} />
<div className="flex flex-col items-center justify-center min-h-screen">
<CardSection className="max-w-md">
<div className="flex-1 flex flex-col items-center justify-center">
<CardSection className="max-w-md w-[500px] h-[250px] p-8">
<h1 className="text-2xl font-bold mb-4">{statusMessage}</h1>
<p className="text-text-500">{statusDetails}</p>
{redirectUrl && !isError && (

View File

@ -0,0 +1,293 @@
"use client";
import { useEffect, useState } from "react";
import { usePathname, useRouter, useSearchParams } from "next/navigation";
import { AdminPageTitle } from "@/components/admin/Title";
import { Button } from "@/components/ui/button";
import Title from "@/components/ui/title";
import { KeyIcon } from "@/components/icons/icons";
import { getSourceMetadata, isValidSource } from "@/lib/sources";
import { ConfluenceAccessibleResource, ValidSources } from "@/lib/types";
import CardSection from "@/components/admin/CardSection";
import {
handleOAuthAuthorizationResponse,
handleOAuthConfluenceFinalize,
handleOAuthPrepareFinalization,
} from "@/lib/oauth_utils";
import { SelectorFormField } from "@/components/admin/connectors/Field";
import { ErrorMessage, Field, Form, Formik, useFormikContext } from "formik";
import * as Yup from "yup";
// Helper component to keep the effect logic clean:
function UpdateCloudURLOnCloudIdChange({
accessibleResources,
}: {
accessibleResources: ConfluenceAccessibleResource[];
}) {
const { values, setValues, setFieldValue } = useFormikContext<{
cloud_id: string;
cloud_name: string;
cloud_url: string;
}>();
useEffect(() => {
// Whenever cloud_id changes, find the matching resource and update cloud_url
if (values.cloud_id) {
const selectedResource = accessibleResources.find(
(resource) => resource.id === values.cloud_id
);
if (selectedResource) {
// Update multiple fields together ... somehow setting them in sequence
// doesn't work with the validator
// it may also be possible to await each setFieldValue call.
// https://github.com/jaredpalmer/formik/issues/2266
setValues((prevValues) => ({
...prevValues,
cloud_name: selectedResource.name,
cloud_url: selectedResource.url,
}));
}
}
}, [values.cloud_id, accessibleResources, setFieldValue]);
// This component doesn't render anything visible:
return null;
}
export default function OAuthFinalizePage() {
const router = useRouter();
const searchParams = useSearchParams();
const [statusMessage, setStatusMessage] = useState("Processing...");
const [statusDetails, setStatusDetails] = useState(
"Please wait while we complete the setup."
);
const [redirectUrl, setRedirectUrl] = useState<string | null>(null);
const [isError, setIsError] = useState(false);
const [isSubmitted, setIsSubmitted] = useState(false); // New state
const [pageTitle, setPageTitle] = useState(
"Finalize Authorization with Third-Party service"
);
const [accessibleResources, setAccessibleResources] = useState<
ConfluenceAccessibleResource[]
>([]);
// Extract query parameters
const credentialParam = searchParams.get("credential");
const credential = credentialParam ? parseInt(credentialParam, 10) : NaN;
const pathname = usePathname();
const connector = pathname?.split("/")[3];
useEffect(() => {
const onFirstLoad = async () => {
// Examples
// connector (url segment)= "google-drive"
// sourceType (for looking up metadata) = "google_drive"
if (isNaN(credential)) {
setStatusMessage("Improperly formed OAuth finalization request.");
setStatusDetails("Invalid or missing credential id.");
setIsError(true);
return;
}
const sourceType = connector.replaceAll("-", "_");
if (!isValidSource(sourceType)) {
setStatusMessage(
`The specified connector source type ${sourceType} does not exist.`
);
setStatusDetails(`${sourceType} is not a valid source type.`);
setIsError(true);
return;
}
const sourceMetadata = getSourceMetadata(sourceType as ValidSources);
setPageTitle(`Finalize Authorization with ${sourceMetadata.displayName}`);
setStatusMessage("Processing...");
setStatusDetails(
"Please wait while we retrieve a list of your accessible sites."
);
setIsError(false); // Ensure no error state during loading
try {
const response = await handleOAuthPrepareFinalization(
connector,
credential
);
if (!response) {
throw new Error("Empty response from OAuth server.");
}
setAccessibleResources(response.accessible_resources);
setStatusMessage("Select a Confluence site");
setStatusDetails("");
setIsError(false);
} catch (error) {
console.error("OAuth finalization error:", error);
setStatusMessage("Oops, something went wrong!");
setStatusDetails(
"An error occurred during the OAuth finalization process. Please try again."
);
setIsError(true);
}
};
onFirstLoad();
}, [credential, connector]);
useEffect(() => {}, [redirectUrl]);
return (
<div className="mx-auto h-screen flex flex-col">
<AdminPageTitle title={pageTitle} icon={<KeyIcon size={32} />} />
<div className="flex-1 flex flex-col items-center justify-center">
<CardSection className="max-w-md w-[500px] h-[250px] p-8">
<h1 className="text-2xl font-bold mb-4">{statusMessage}</h1>
<p className="text-text-500">{statusDetails}</p>
<Formik
initialValues={{
credential_id: credential,
cloud_id: "",
cloud_name: "",
cloud_url: "",
}}
validationSchema={Yup.object().shape({
credential_id: Yup.number().required(
"Credential ID is required."
),
cloud_id: Yup.string().required(
"You must select a Confluence site (id not found)."
),
cloud_name: Yup.string().required(
"You must select a Confluence site (name not found)."
),
cloud_url: Yup.string().required(
"You must select a Confluence site (url not found)."
),
})}
validateOnMount
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
try {
if (!values.cloud_id) {
throw new Error("Cloud ID is required.");
}
if (!values.cloud_name) {
throw new Error("Cloud URL is required.");
}
if (!values.cloud_url) {
throw new Error("Cloud URL is required.");
}
const response = await handleOAuthConfluenceFinalize(
values.credential_id,
values.cloud_id,
values.cloud_name,
values.cloud_url
);
formikHelpers.setSubmitting(false);
if (response) {
setRedirectUrl(response.redirect_url);
setStatusMessage("Confluence authorization finalized.");
}
setIsSubmitted(true); // Mark as submitted
} catch (error) {
console.error(error);
setStatusMessage("Error during submission.");
setStatusDetails(
"An error occurred during the submission process. Please try again."
);
setIsError(true);
formikHelpers.setSubmitting(false);
}
}}
>
{({ isSubmitting, isValid, setFieldValue }) => (
<Form>
{/* Debug info
<div className="mb-4 p-2 bg-gray-100 rounded text-xs">
<pre>
isValid: {String(isValid)}
errors: {JSON.stringify(errors, null, 2)}
values: {JSON.stringify(values, null, 2)}
</pre>
</div> */}
{/* Our helper component that reacts to changes in cloud_id */}
<UpdateCloudURLOnCloudIdChange
accessibleResources={accessibleResources}
/>
<Field type="hidden" name="cloud_name" />
<ErrorMessage
name="cloud_name"
component="div"
className="error"
/>
<Field type="hidden" name="cloud_url" />
<ErrorMessage
name="cloud_url"
component="div"
className="error"
/>
{!redirectUrl && accessibleResources.length > 0 && (
<SelectorFormField
name="cloud_id"
options={accessibleResources.map((resource) => ({
name: `${resource.name} - ${resource.url}`,
value: resource.id,
}))}
onSelect={(selectedValue) => {
const selectedResource = accessibleResources.find(
(resource) => resource.id === selectedValue
);
if (selectedResource) {
setFieldValue("cloud_id", selectedResource.id);
}
}}
/>
)}
<br />
{!redirectUrl && (
<Button
type="submit"
size="sm"
variant="submit"
disabled={!isValid || isSubmitting}
>
{isSubmitting ? "Submitting..." : "Submit"}
</Button>
)}
</Form>
)}
</Formik>
{redirectUrl && !isError && (
<div className="mt-4">
<p className="text-sm">
Authorization finalized. Click{" "}
<a href={redirectUrl} className="text-blue-500 underline">
here
</a>{" "}
to continue.
</p>
</div>
)}
</CardSection>
</div>
</div>
);
}

View File

@ -1,4 +1,10 @@
import React, { Dispatch, FC, SetStateAction, useState } from "react";
import React, {
Dispatch,
FC,
SetStateAction,
useEffect,
useState,
} from "react";
import CredentialSubText from "@/components/credentials/CredentialFields";
import { ConnectionConfiguration } from "@/lib/connectors/connectors";
import { TextFormField } from "@/components/admin/connectors/Field";
@ -8,6 +14,7 @@ import { AccessTypeGroupSelector } from "@/components/admin/connectors/AccessTyp
import { ConfigurableSources } from "@/lib/types";
import { Credential } from "@/lib/connectors/credentials";
import { RenderField } from "./FieldRendering";
import { useFormikContext } from "formik";
export interface DynamicConnectionFormProps {
config: ConnectionConfiguration;
@ -22,7 +29,25 @@ const DynamicConnectionForm: FC<DynamicConnectionFormProps> = ({
connector,
currentCredential,
}) => {
const { setFieldValue } = useFormikContext<any>(); // Get Formik's context functions
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
const [connectorNameInitialized, setConnectorNameInitialized] =
useState(false);
let initialConnectorName = "";
if (config.initialConnectorName) {
initialConnectorName =
currentCredential?.credential_json?.[config.initialConnectorName] ?? "";
}
useEffect(() => {
const field_value = values["name"];
if (initialConnectorName && !connectorNameInitialized && !field_value) {
setFieldValue("name", initialConnectorName);
setConnectorNameInitialized(true);
}
}, [initialConnectorName, setFieldValue, values]);
return (
<>

View File

@ -1,4 +1,4 @@
import React, { Dispatch, FC, SetStateAction } from "react";
import React, { Dispatch, FC, SetStateAction, useEffect } from "react";
import { AdminBooleanFormField } from "@/components/credentials/CredentialFields";
import { FileUpload } from "@/components/admin/connectors/FileUpload";
import { TabOption } from "@/lib/connectors/connectors";
@ -16,6 +16,7 @@ import {
TabsList,
TabsTrigger,
} from "@/components/ui/fully_wrapped_tabs";
import { useFormikContext } from "formik";
interface TabsFieldProps {
tabField: TabOption;
@ -123,6 +124,8 @@ export const RenderField: FC<RenderFieldProps> = ({
connector,
currentCredential,
}) => {
const { setFieldValue } = useFormikContext<any>(); // Get Formik's context functions
const label =
typeof field.label === "function"
? field.label(currentCredential)
@ -131,6 +134,22 @@ export const RenderField: FC<RenderFieldProps> = ({
typeof field.description === "function"
? field.description(currentCredential)
: field.description;
const disabled =
typeof field.disabled === "function"
? field.disabled(currentCredential)
: (field.disabled ?? false);
const initialValue =
typeof field.initial === "function"
? field.initial(currentCredential)
: (field.initial ?? "");
// if initialValue exists, prepopulate the field with it
useEffect(() => {
const field_value = values[field.name];
if (initialValue && field_value === undefined) {
setFieldValue(field.name, initialValue);
}
}, [field.name, initialValue, setFieldValue, values]);
if (field.type === "tab") {
return (
@ -176,6 +195,8 @@ export const RenderField: FC<RenderFieldProps> = ({
subtext={description}
name={field.name}
label={label}
disabled={disabled}
onChange={(e) => setFieldValue(field.name, e.target.value)}
/>
) : field.type === "text" ? (
<TextFormField
@ -186,6 +207,8 @@ export const RenderField: FC<RenderFieldProps> = ({
name={field.name}
isTextArea={field.isTextArea || false}
defaultHeight={"h-15"}
disabled={disabled}
onChange={(e) => setFieldValue(field.name, e.target.value)}
/>
) : field.type === "string_tab" ? (
<div className="text-center">{description}</div>

View File

@ -132,7 +132,6 @@ export function TextFormField({
label,
subtext,
placeholder,
value,
type = "text",
optional,
includeRevert,
@ -157,7 +156,6 @@ export function TextFormField({
vertical,
className,
}: {
value?: string; // Escape hatch for setting the value of the field - conflicts with Formik
name: string;
removeLabel?: boolean;
label: string;
@ -253,7 +251,6 @@ export function TextFormField({
min={min}
as={isTextArea ? "textarea" : "input"}
type={type}
defaultValue={value}
data-testid={name}
name={name}
id={name}

View File

@ -130,6 +130,7 @@ interface BooleanFormFieldProps {
small?: boolean;
alignTop?: boolean;
noLabel?: boolean;
disabled?: boolean;
onChange?: (e: React.ChangeEvent<HTMLInputElement>) => void;
}
@ -141,6 +142,7 @@ export const AdminBooleanFormField = ({
small,
checked,
alignTop,
disabled = false,
onChange,
}: BooleanFormFieldProps) => {
const [field, meta, helpers] = useField(name);
@ -152,6 +154,7 @@ export const AdminBooleanFormField = ({
type="checkbox"
{...field}
checked={Boolean(field.value)}
disabled={disabled}
onChange={(e) => {
helpers.setValue(e.target.checked);
}}

View File

@ -43,6 +43,7 @@ export interface Option {
currentCredential: Credential<any> | null
) => boolean;
wrapInCollapsible?: boolean;
disabled?: boolean | ((currentCredential: Credential<any> | null) => boolean);
}
export interface SelectOption extends Option {
@ -60,6 +61,7 @@ export interface ListOption extends Option {
export interface TextOption extends Option {
type: "text";
default?: string;
initial?: string | ((currentCredential: Credential<any> | null) => string);
isTextArea?: boolean;
}
@ -105,6 +107,7 @@ export interface TabOption extends Option {
export interface ConnectionConfiguration {
description: string;
subtext?: string;
initialConnectorName?: string; // a key in the credential to prepopulate the connector name field
values: (
| BooleanOption
| ListOption
@ -389,6 +392,7 @@ export const connectorConfigs: Record<
},
confluence: {
description: "Configure Confluence connector",
initialConnectorName: "cloud_name",
values: [
{
type: "checkbox",
@ -399,6 +403,12 @@ export const connectorConfigs: Record<
default: true,
description:
"Check if this is a Confluence Cloud instance, uncheck for Confluence Server/Data Center",
disabled: (currentCredential) => {
if (currentCredential?.credential_json?.confluence_refresh_token) {
return true;
}
return false;
},
},
{
type: "text",
@ -406,6 +416,15 @@ export const connectorConfigs: Record<
label: "Wiki Base URL",
name: "wiki_base",
optional: false,
initial: (currentCredential) => {
return currentCredential?.credential_json?.wiki_base ?? "";
},
disabled: (currentCredential) => {
if (currentCredential?.credential_json?.confluence_refresh_token) {
return true;
}
return false;
},
description:
"The base URL of your Confluence instance (e.g., https://your-domain.atlassian.net/wiki)",
},

View File

@ -27,6 +27,9 @@ export async function getConnectorOauthRedirectUrl(
export function useOAuthDetails(sourceType: ValidSources) {
return useSWR<OAuthDetails>(
`/api/connector/oauth/details/${sourceType}`,
errorHandlingFetcher
errorHandlingFetcher,
{
shouldRetryOnError: false,
}
);
}

View File

@ -1,5 +1,7 @@
import {
OAuthGoogleDriveCallbackResponse,
OAuthBaseCallbackResponse,
OAuthConfluenceFinalizeResponse,
OAuthConfluencePrepareFinalizationResponse,
OAuthPrepareAuthorizationResponse,
OAuthSlackCallbackResponse,
} from "./types";
@ -53,6 +55,10 @@ export async function handleOAuthAuthorizationResponse(
return handleOAuthGoogleDriveAuthorizationResponse(code, state);
}
if (connector === "confluence") {
return handleOAuthConfluenceAuthorizationResponse(code, state);
}
return;
}
@ -75,7 +81,7 @@ export async function handleOAuthSlackAuthorizationResponse(
});
if (!response.ok) {
let errorDetails = `Failed to handle OAuth authorization response: ${response.status}`;
let errorDetails = `Failed to handle OAuth Slack authorization response: ${response.status}`;
try {
const responseBody = await response.text(); // Read the body as text
@ -96,12 +102,10 @@ export async function handleOAuthSlackAuthorizationResponse(
return data;
}
// server side handler to process the oauth redirect callback
// https://api.slack.com/authentication/oauth-v2#exchanging
export async function handleOAuthGoogleDriveAuthorizationResponse(
code: string,
state: string
): Promise<OAuthGoogleDriveCallbackResponse> {
): Promise<OAuthBaseCallbackResponse> {
const url = `/api/oauth/connector/google-drive/callback?code=${encodeURIComponent(
code
)}&state=${encodeURIComponent(state)}`;
@ -115,7 +119,7 @@ export async function handleOAuthGoogleDriveAuthorizationResponse(
});
if (!response.ok) {
let errorDetails = `Failed to handle OAuth authorization response: ${response.status}`;
let errorDetails = `Failed to handle OAuth Google Drive authorization response: ${response.status}`;
try {
const responseBody = await response.text(); // Read the body as text
@ -132,6 +136,137 @@ export async function handleOAuthGoogleDriveAuthorizationResponse(
}
// Parse the JSON response
const data = (await response.json()) as OAuthGoogleDriveCallbackResponse;
const data = (await response.json()) as OAuthBaseCallbackResponse;
return data;
}
// call server side helper
// https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps
export async function handleOAuthConfluenceAuthorizationResponse(
code: string,
state: string
): Promise<OAuthBaseCallbackResponse> {
const url = `/api/oauth/connector/confluence/callback?code=${encodeURIComponent(
code
)}&state=${encodeURIComponent(state)}`;
const response = await fetch(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ code, state }),
});
if (!response.ok) {
let errorDetails = `Failed to handle OAuth Confluence authorization response: ${response.status}`;
try {
const responseBody = await response.text(); // Read the body as text
errorDetails += `\nResponse Body: ${responseBody}`;
} catch (err) {
if (err instanceof Error) {
errorDetails += `\nUnable to read response body: ${err.message}`;
} else {
errorDetails += `\nUnable to read response body: Unknown error type`;
}
}
throw new Error(errorDetails);
}
// Parse the JSON response
const data = (await response.json()) as OAuthBaseCallbackResponse;
return data;
}
export async function handleOAuthPrepareFinalization(
connector: string,
credential: number
) {
if (connector === "confluence") {
return handleOAuthConfluencePrepareFinalization(credential);
}
return;
}
// call server side helper
// https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps
export async function handleOAuthConfluencePrepareFinalization(
credential: number
): Promise<OAuthConfluencePrepareFinalizationResponse> {
const url = `/api/oauth/connector/confluence/accessible-resources?credential_id=${encodeURIComponent(
credential
)}`;
const response = await fetch(url, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});
if (!response.ok) {
let errorDetails = `Failed to handle OAuth Confluence prepare finalization response: ${response.status}`;
try {
const responseBody = await response.text(); // Read the body as text
errorDetails += `\nResponse Body: ${responseBody}`;
} catch (err) {
if (err instanceof Error) {
errorDetails += `\nUnable to read response body: ${err.message}`;
} else {
errorDetails += `\nUnable to read response body: Unknown error type`;
}
}
throw new Error(errorDetails);
}
// Parse the JSON response
const data =
(await response.json()) as OAuthConfluencePrepareFinalizationResponse;
return data;
}
export async function handleOAuthConfluenceFinalize(
credential_id: number,
cloud_id: string,
cloud_name: string,
cloud_url: string
): Promise<OAuthConfluenceFinalizeResponse> {
const url = `/api/oauth/connector/confluence/finalize?credential_id=${encodeURIComponent(
credential_id
)}&cloud_id=${encodeURIComponent(cloud_id)}&cloud_name=${encodeURIComponent(
cloud_name
)}&cloud_url=${encodeURIComponent(cloud_url)}`;
const response = await fetch(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
});
if (!response.ok) {
let errorDetails = `Failed to handle OAuth Confluence finalization response: ${response.status}`;
try {
const responseBody = await response.text(); // Read the body as text
errorDetails += `\nResponse Body: ${responseBody}`;
} catch (err) {
if (err instanceof Error) {
errorDetails += `\nUnable to read response body: ${err.message}`;
} else {
errorDetails += `\nUnable to read response body: Unknown error type`;
}
}
throw new Error(errorDetails);
}
// Parse the JSON response
const data = (await response.json()) as OAuthConfluenceFinalizeResponse;
return data;
}

View File

@ -120,6 +120,7 @@ export const SOURCE_METADATA_MAP: SourceMap = {
displayName: "Confluence",
category: SourceCategory.Wiki,
docs: "https://docs.onyx.app/connectors/confluence",
oauthSupported: true,
},
jira: {
icon: JiraIcon,

View File

@ -167,18 +167,36 @@ export interface OAuthPrepareAuthorizationResponse {
url: string;
}
export interface OAuthSlackCallbackResponse {
export interface OAuthBaseCallbackResponse {
success: boolean;
message: string;
team_id: string;
authed_user_id: string;
finalize_url: string | null;
redirect_on_success: string;
}
export interface OAuthGoogleDriveCallbackResponse {
export interface OAuthSlackCallbackResponse extends OAuthBaseCallbackResponse {
team_id: string;
authed_user_id: string;
}
export interface ConfluenceAccessibleResource {
id: string;
name: string;
url: string;
scopes: string[];
avatarUrl: string;
}
export interface OAuthConfluencePrepareFinalizationResponse {
success: boolean;
message: string;
redirect_on_success: string;
accessible_resources: ConfluenceAccessibleResource[];
}
export interface OAuthConfluenceFinalizeResponse {
success: boolean;
message: string;
redirect_url: string;
}
export interface CCPairBasicInfo {
@ -382,6 +400,7 @@ export const oauthSupportedSources: ConfigurableSources[] = [
ValidSources.Slack,
// NOTE: temporarily disabled until our GDrive App is approved
// ValidSources.GoogleDrive,
ValidSources.Confluence,
];
export type OAuthSupportedSource = (typeof oauthSupportedSources)[number];