mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-17 21:32:36 +01:00
Feature/confluence oauth (#3477)
* first cut at slack oauth flow * fix usage of hooks * fix button spacing * add additional error logging * no dev redirect * early cut at google drive oauth * second pass * switch to production uri's * try handling oauth_interactive differently * pass through client id and secret if uploaded * fix call * fix test * temporarily disable check for testing * Revert "temporarily disable check for testing" This reverts commit 4b5a022a5fe38b05355a561616068af8e969def2. * support visibility in test * missed file * first cut at confluence oauth * work in progress * work in progress * work in progress * work in progress * work in progress * first cut at distributed locking * WIP to make test work * add some dev mode affordances and gate usage of redis behind dynamic credentials * mypy and credentials provider fixes * WIP * fix created at * fix setting initialValue on everything * remove debugging, fix ??? some TextFormField issues * npm fixes * comment cleanup * fix comments * pin the size of the card section * more review fixes * more fixes --------- Co-authored-by: Richard Kuo <rkuo@rkuo.com> Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
This commit is contained in:
parent
cd84b65011
commit
909403a648
@ -59,10 +59,14 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
|
||||
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
|
||||
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
|
||||
OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "")
|
||||
OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "")
|
||||
OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "")
|
||||
OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "")
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
|
||||
"OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", ""
|
||||
)
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", ""
|
||||
)
|
||||
OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "")
|
||||
OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "")
|
||||
OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "")
|
||||
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
|
||||
|
@ -9,12 +9,16 @@ from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GR
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
get_user_email_from_username__server,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import get_user_email_from_username__server
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@ -342,7 +346,8 @@ def _fetch_all_page_restrictions(
|
||||
|
||||
|
||||
def confluence_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@ -354,7 +359,11 @@ def confluence_doc_sync(
|
||||
confluence_connector = ConfluenceConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
confluence_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
provider = OnyxDBCredentialsProvider(
|
||||
get_current_tenant_id(), "confluence", cc_pair.credential_id
|
||||
)
|
||||
confluence_connector.set_credentials_provider(provider)
|
||||
|
||||
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
||||
|
||||
|
@ -1,9 +1,11 @@
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
get_user_email_from_username__server,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import get_user_email_from_username__server
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@ -61,13 +63,27 @@ def _build_group_member_email_map(
|
||||
|
||||
|
||||
def confluence_group_sync(
|
||||
tenant_id: str,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[ExternalUserGroup]:
|
||||
confluence_client = build_confluence_client(
|
||||
credentials=cc_pair.credential.credential_json,
|
||||
is_cloud=cc_pair.connector.connector_specific_config.get("is_cloud", False),
|
||||
wiki_base=cc_pair.connector.connector_specific_config["wiki_base"],
|
||||
)
|
||||
provider = OnyxDBCredentialsProvider(tenant_id, "confluence", cc_pair.credential_id)
|
||||
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
||||
wiki_base: str = cc_pair.connector.connector_specific_config["wiki_base"]
|
||||
url = wiki_base.rstrip("/")
|
||||
|
||||
probe_kwargs = {
|
||||
"max_backoff_retries": 6,
|
||||
"max_backoff_seconds": 10,
|
||||
}
|
||||
|
||||
final_kwargs = {
|
||||
"max_backoff_retries": 10,
|
||||
"max_backoff_seconds": 60,
|
||||
}
|
||||
|
||||
confluence_client = OnyxConfluence(is_cloud, url, provider)
|
||||
confluence_client._probe_connection(**probe_kwargs)
|
||||
confluence_client._initialize_connection(**final_kwargs)
|
||||
|
||||
group_member_email_map = _build_group_member_email_map(
|
||||
confluence_client=confluence_client,
|
||||
|
@ -32,7 +32,8 @@ def _get_slim_doc_generator(
|
||||
|
||||
|
||||
def gmail_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
|
@ -145,7 +145,8 @@ def _get_permissions_from_slim_doc(
|
||||
|
||||
|
||||
def gdrive_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
|
@ -119,6 +119,7 @@ def _build_onyx_groups(
|
||||
|
||||
|
||||
def gdrive_group_sync(
|
||||
tenant_id: str,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[ExternalUserGroup]:
|
||||
# Initialize connector and build credential/service objects
|
||||
|
@ -123,7 +123,8 @@ def _fetch_channel_permissions(
|
||||
|
||||
|
||||
def slack_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
|
@ -28,6 +28,7 @@ DocSyncFuncType = Callable[
|
||||
|
||||
GroupSyncFuncType = Callable[
|
||||
[
|
||||
str,
|
||||
ConnectorCredentialPair,
|
||||
],
|
||||
list[ExternalUserGroup],
|
||||
|
@ -15,7 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
|
||||
)
|
||||
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
|
||||
from ee.onyx.server.oauth import router as oauth_router
|
||||
from ee.onyx.server.oauth.api import router as oauth_router
|
||||
from ee.onyx.server.query_and_chat.chat_backend import (
|
||||
router as chat_router,
|
||||
)
|
||||
@ -152,4 +152,8 @@ def get_application() -> FastAPI:
|
||||
# environment variable. Used to automate deployment for multiple environments.
|
||||
seed_db()
|
||||
|
||||
# for debugging discovered routes
|
||||
# for route in application.router.routes:
|
||||
# print(f"Path: {route.path}, Methods: {route.methods}")
|
||||
|
||||
return application
|
||||
|
@ -1,629 +0,0 @@
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
|
||||
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/oauth")
|
||||
|
||||
|
||||
class SlackOAuth:
|
||||
# https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
|
||||
# Example: https://api.slack.com/authentication/oauth-v2#exchanging
|
||||
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
CLIENT_ID = OAUTH_SLACK_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET
|
||||
|
||||
TOKEN_URL = "https://slack.com/api/oauth.v2.access"
|
||||
|
||||
# SCOPE is per https://docs.onyx.app/connectors/slack
|
||||
BOT_SCOPE = (
|
||||
"channels:history,"
|
||||
"channels:read,"
|
||||
"groups:history,"
|
||||
"groups:read,"
|
||||
"channels:join,"
|
||||
"im:history,"
|
||||
"users:read,"
|
||||
"users:read.email,"
|
||||
"usergroups:read"
|
||||
)
|
||||
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
url = (
|
||||
f"https://slack.com/oauth/v2/authorize"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
f"&scope={cls.BOT_SCOPE}"
|
||||
f"&state={state}"
|
||||
)
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = SlackOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> OAuthSession:
|
||||
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
|
||||
class ConfluenceCloudOAuth:
|
||||
"""work in progress"""
|
||||
|
||||
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
|
||||
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
|
||||
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
|
||||
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
|
||||
CONFLUENCE_OAUTH_SCOPE = (
|
||||
"read:confluence-props%20"
|
||||
"read:confluence-content.all%20"
|
||||
"read:confluence-content.summary%20"
|
||||
"read:confluence-content.permission%20"
|
||||
"read:confluence-user%20"
|
||||
"read:confluence-groups%20"
|
||||
"readonly:content.attachment:confluence"
|
||||
)
|
||||
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
# eventually for Confluence Data Center
|
||||
# oauth_url = (
|
||||
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
|
||||
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
|
||||
# f"&redirect_uri={redirectme_uri}"
|
||||
# )
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
url = (
|
||||
"https://auth.atlassian.com/authorize"
|
||||
f"?audience=api.atlassian.com"
|
||||
f"&client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
|
||||
f"&state={state}"
|
||||
"&response_type=code"
|
||||
"&prompt=consent"
|
||||
)
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = ConfluenceCloudOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
|
||||
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
|
||||
class GoogleDriveOAuth:
|
||||
# https://developers.google.com/identity/protocols/oauth2
|
||||
# https://developers.google.com/identity/protocols/oauth2/web-server
|
||||
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
|
||||
TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
|
||||
# SCOPE is per https://docs.onyx.app/connectors/google-drive
|
||||
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
|
||||
SCOPE = (
|
||||
"https://www.googleapis.com/auth/drive.readonly%20"
|
||||
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
|
||||
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
|
||||
"https://www.googleapis.com/auth/admin.directory.group.readonly"
|
||||
)
|
||||
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
# without prompt=consent, a refresh token is only issued the first time the user approves
|
||||
url = (
|
||||
f"https://accounts.google.com/o/oauth2/v2/auth"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
"&response_type=code"
|
||||
f"&scope={cls.SCOPE}"
|
||||
"&access_type=offline"
|
||||
f"&state={state}"
|
||||
"&prompt=consent"
|
||||
)
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = GoogleDriveOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> OAuthSession:
|
||||
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
|
||||
@router.post("/prepare-authorization-request")
|
||||
def prepare_authorization_request(
|
||||
connector: DocumentSource,
|
||||
redirect_on_success: str | None,
|
||||
user: User = Depends(current_user),
|
||||
) -> JSONResponse:
|
||||
"""Used by the frontend to generate the url for the user's browser during auth request.
|
||||
|
||||
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
|
||||
"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
# create random oauth state param for security and to retrieve user data later
|
||||
oauth_uuid = uuid.uuid4()
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
# urlsafe b64 encode the uuid for the oauth url
|
||||
oauth_state = (
|
||||
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
|
||||
)
|
||||
session: str
|
||||
|
||||
if connector == DocumentSource.SLACK:
|
||||
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
|
||||
session = SlackOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
elif connector == DocumentSource.GOOGLE_DRIVE:
|
||||
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
|
||||
session = GoogleDriveOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
# elif connector == DocumentSource.CONFLUENCE:
|
||||
# oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
|
||||
# session = ConfluenceCloudOAuth.session_dump_json(
|
||||
# email=user.email, redirect_on_success=redirect_on_success
|
||||
# )
|
||||
# elif connector == DocumentSource.JIRA:
|
||||
# oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
|
||||
else:
|
||||
oauth_url = None
|
||||
|
||||
if not oauth_url:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"The document source type {connector} does not have OAuth implemented",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# store important session state to retrieve when the user is redirected back
|
||||
# 10 min is the max we want an oauth flow to be valid
|
||||
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
|
||||
|
||||
return JSONResponse(content={"url": oauth_url})
|
||||
|
||||
|
||||
@router.post("/connector/slack/callback")
|
||||
def handle_slack_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> JSONResponse:
|
||||
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Slack client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client()
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
-len(state) % 4
|
||||
) # Add padding back (Base64 decoding requires padding)
|
||||
uuid_bytes = base64.urlsafe_b64decode(
|
||||
padded_state
|
||||
) # Decode the Base64 string back to bytes
|
||||
|
||||
# Convert bytes back to a UUID
|
||||
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
session_json_bytes = cast(bytes, r.get(r_key))
|
||||
if not session_json_bytes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
|
||||
)
|
||||
|
||||
session_json = session_json_bytes.decode("utf-8")
|
||||
try:
|
||||
session = SlackOAuth.parse_session(session_json)
|
||||
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.post(
|
||||
SlackOAuth.TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"client_id": SlackOAuth.CLIENT_ID,
|
||||
"client_secret": SlackOAuth.CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": SlackOAuth.REDIRECT_URI,
|
||||
},
|
||||
)
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
if not response_data.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Slack OAuth failed: {response_data.get('error')}",
|
||||
)
|
||||
|
||||
# Extract token and team information
|
||||
access_token: str = response_data.get("access_token")
|
||||
team_id: str = response_data.get("team", {}).get("id")
|
||||
authed_user_id: str = response_data.get("authed_user", {}).get("id")
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json={"slack_bot_token": access_token},
|
||||
admin_public=True,
|
||||
source=DocumentSource.SLACK,
|
||||
name="Slack OAuth",
|
||||
)
|
||||
|
||||
create_credential(credential_info, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Slack OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
r.delete(r_key)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Slack OAuth completed successfully.",
|
||||
"team_id": team_id,
|
||||
"authed_user_id": authed_user_id,
|
||||
"redirect_on_success": session.redirect_on_success,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# Work in progress
|
||||
# @router.post("/connector/confluence/callback")
|
||||
# def handle_confluence_oauth_callback(
|
||||
# code: str,
|
||||
# state: str,
|
||||
# user: User = Depends(current_user),
|
||||
# db_session: Session = Depends(get_session),
|
||||
# tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
# ) -> JSONResponse:
|
||||
# if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
|
||||
# raise HTTPException(
|
||||
# status_code=500,
|
||||
# detail="Confluence client ID or client secret is not configured."
|
||||
# )
|
||||
|
||||
# r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# # recover the state
|
||||
# padded_state = state + '=' * (-len(state) % 4) # Add padding back (Base64 decoding requires padding)
|
||||
# uuid_bytes = base64.urlsafe_b64decode(padded_state) # Decode the Base64 string back to bytes
|
||||
|
||||
# # Convert bytes back to a UUID
|
||||
# oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
# oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
# r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
# result = r.get(r_key)
|
||||
# if not result:
|
||||
# raise HTTPException(
|
||||
# status_code=400,
|
||||
# detail=f"Confluence OAuth failed - OAuth state key not found: key={r_key}"
|
||||
# )
|
||||
|
||||
# try:
|
||||
# session = ConfluenceCloudOAuth.parse_session(result)
|
||||
|
||||
# # Exchange the authorization code for an access token
|
||||
# response = requests.post(
|
||||
# ConfluenceCloudOAuth.TOKEN_URL,
|
||||
# headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
# data={
|
||||
# "client_id": ConfluenceCloudOAuth.CLIENT_ID,
|
||||
# "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
|
||||
# "code": code,
|
||||
# "redirect_uri": ConfluenceCloudOAuth.DEV_REDIRECT_URI,
|
||||
# },
|
||||
# )
|
||||
|
||||
# response_data = response.json()
|
||||
|
||||
# if not response_data.get("ok"):
|
||||
# raise HTTPException(
|
||||
# status_code=400,
|
||||
# detail=f"ConfluenceCloudOAuth OAuth failed: {response_data.get('error')}"
|
||||
# )
|
||||
|
||||
# # Extract token and team information
|
||||
# access_token: str = response_data.get("access_token")
|
||||
# team_id: str = response_data.get("team", {}).get("id")
|
||||
# authed_user_id: str = response_data.get("authed_user", {}).get("id")
|
||||
|
||||
# credential_info = CredentialBase(
|
||||
# credential_json={"slack_bot_token": access_token},
|
||||
# admin_public=True,
|
||||
# source=DocumentSource.CONFLUENCE,
|
||||
# name="Confluence OAuth",
|
||||
# )
|
||||
|
||||
# logger.info(f"Slack access token: {access_token}")
|
||||
|
||||
# credential = create_credential(credential_info, user, db_session)
|
||||
|
||||
# logger.info(f"new_credential_id={credential.id}")
|
||||
# except Exception as e:
|
||||
# return JSONResponse(
|
||||
# status_code=500,
|
||||
# content={
|
||||
# "success": False,
|
||||
# "message": f"An error occurred during Slack OAuth: {str(e)}",
|
||||
# },
|
||||
# )
|
||||
# finally:
|
||||
# r.delete(r_key)
|
||||
|
||||
# # return the result
|
||||
# return JSONResponse(
|
||||
# content={
|
||||
# "success": True,
|
||||
# "message": "Slack OAuth completed successfully.",
|
||||
# "team_id": team_id,
|
||||
# "authed_user_id": authed_user_id,
|
||||
# "redirect_on_success": session.redirect_on_success,
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
@router.post("/connector/google-drive/callback")
|
||||
def handle_google_drive_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> JSONResponse:
|
||||
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Google Drive client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client()
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
-len(state) % 4
|
||||
) # Add padding back (Base64 decoding requires padding)
|
||||
uuid_bytes = base64.urlsafe_b64decode(
|
||||
padded_state
|
||||
) # Decode the Base64 string back to bytes
|
||||
|
||||
# Convert bytes back to a UUID
|
||||
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
session_json_bytes = cast(bytes, r.get(r_key))
|
||||
if not session_json_bytes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
|
||||
)
|
||||
|
||||
session_json = session_json_bytes.decode("utf-8")
|
||||
session: GoogleDriveOAuth.OAuthSession
|
||||
try:
|
||||
session = GoogleDriveOAuth.parse_session(session_json)
|
||||
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.post(
|
||||
GoogleDriveOAuth.TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"client_id": GoogleDriveOAuth.CLIENT_ID,
|
||||
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
authorization_response: dict[str, Any] = response.json()
|
||||
|
||||
# the connector wants us to store the json in its authorized_user_info format
|
||||
# returned from OAuthCredentials.get_authorized_user_info().
|
||||
# So refresh immediately via get_google_oauth_creds with the params filled in
|
||||
# from fields in authorization_response to get the json we need
|
||||
authorized_user_info = {}
|
||||
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
|
||||
|
||||
token_json_str = json.dumps(authorized_user_info)
|
||||
oauth_creds = get_google_oauth_creds(
|
||||
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
|
||||
)
|
||||
if not oauth_creds:
|
||||
raise RuntimeError("get_google_oauth_creds returned None.")
|
||||
|
||||
# save off the credentials
|
||||
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
|
||||
|
||||
credential_dict: dict[str, str] = {}
|
||||
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
|
||||
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
|
||||
credential_dict[
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD
|
||||
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
name="OAuth (interactive)",
|
||||
)
|
||||
|
||||
create_credential(credential_info, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
r.delete(r_key)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Google Drive OAuth completed successfully.",
|
||||
"redirect_on_success": session.redirect_on_success,
|
||||
}
|
||||
)
|
91
backend/ee/onyx/server/oauth/api.py
Normal file
91
backend/ee/onyx/server/oauth/api.py
Normal file
@ -0,0 +1,91 @@
|
||||
import base64
|
||||
import uuid
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from ee.onyx.server.oauth.confluence_cloud import ConfluenceCloudOAuth
|
||||
from ee.onyx.server.oauth.google_drive import GoogleDriveOAuth
|
||||
from ee.onyx.server.oauth.slack import SlackOAuth
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@router.post("/prepare-authorization-request")
|
||||
def prepare_authorization_request(
|
||||
connector: DocumentSource,
|
||||
redirect_on_success: str | None,
|
||||
user: User = Depends(current_admin_user),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""Used by the frontend to generate the url for the user's browser during auth request.
|
||||
|
||||
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
|
||||
"""
|
||||
|
||||
# create random oauth state param for security and to retrieve user data later
|
||||
oauth_uuid = uuid.uuid4()
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
# urlsafe b64 encode the uuid for the oauth url
|
||||
oauth_state = (
|
||||
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
|
||||
)
|
||||
|
||||
session: str | None = None
|
||||
if connector == DocumentSource.SLACK:
|
||||
if not DEV_MODE:
|
||||
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
|
||||
else:
|
||||
oauth_url = SlackOAuth.generate_dev_oauth_url(oauth_state)
|
||||
|
||||
session = SlackOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
elif connector == DocumentSource.CONFLUENCE:
|
||||
if not DEV_MODE:
|
||||
oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
|
||||
else:
|
||||
oauth_url = ConfluenceCloudOAuth.generate_dev_oauth_url(oauth_state)
|
||||
session = ConfluenceCloudOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
elif connector == DocumentSource.GOOGLE_DRIVE:
|
||||
if not DEV_MODE:
|
||||
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
|
||||
else:
|
||||
oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
|
||||
session = GoogleDriveOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
else:
|
||||
oauth_url = None
|
||||
|
||||
if not oauth_url:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"The document source type {connector} does not have OAuth implemented",
|
||||
)
|
||||
|
||||
if not session:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"The document source type {connector} failed to generate an OAuth session.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# store important session state to retrieve when the user is redirected back
|
||||
# 10 min is the max we want an oauth flow to be valid
|
||||
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
|
||||
|
||||
return JSONResponse(content={"url": oauth_url})
|
3
backend/ee/onyx/server/oauth/api_router.py
Normal file
3
backend/ee/onyx/server/oauth/api_router.py
Normal file
@ -0,0 +1,3 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/oauth")
|
361
backend/ee/onyx/server/oauth/confluence_cloud.py
Normal file
361
backend/ee/onyx/server/oauth/confluence_cloud.py
Normal file
@ -0,0 +1,361 @@
|
||||
import base64
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.confluence.utils import CONFLUENCE_OAUTH_TOKEN_URL
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.credentials import update_credential_json
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ConfluenceCloudOAuth:
|
||||
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
|
||||
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
expires_in: int
|
||||
token_type: str
|
||||
refresh_token: str
|
||||
scope: str
|
||||
|
||||
class AccessibleResources(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
url: str
|
||||
scopes: list[str]
|
||||
avatarUrl: str
|
||||
|
||||
CLIENT_ID = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
|
||||
TOKEN_URL = CONFLUENCE_OAUTH_TOKEN_URL
|
||||
|
||||
ACCESSIBLE_RESOURCE_URL = (
|
||||
"https://api.atlassian.com/oauth/token/accessible-resources"
|
||||
)
|
||||
|
||||
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
|
||||
CONFLUENCE_OAUTH_SCOPE = (
|
||||
# classic scope
|
||||
"read:confluence-space.summary%20"
|
||||
"read:confluence-props%20"
|
||||
"read:confluence-content.all%20"
|
||||
"read:confluence-content.summary%20"
|
||||
"read:confluence-content.permission%20"
|
||||
"read:confluence-user%20"
|
||||
"read:confluence-groups%20"
|
||||
"readonly:content.attachment:confluence%20"
|
||||
"search:confluence%20"
|
||||
# granular scope
|
||||
"read:attachment:confluence%20" # possibly unneeded unless calling v2 attachments api
|
||||
"offline_access"
|
||||
)
|
||||
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
# eventually for Confluence Data Center
|
||||
# oauth_url = (
|
||||
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
|
||||
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
|
||||
# f"&redirect_uri={redirectme_uri}"
|
||||
# )
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
# https://developer.atlassian.com/cloud/jira/platform/oauth-2-3lo-apps/#1--direct-the-user-to-the-authorization-url-to-get-an-authorization-code
|
||||
|
||||
url = (
|
||||
"https://auth.atlassian.com/authorize"
|
||||
f"?audience=api.atlassian.com"
|
||||
f"&client_id={cls.CLIENT_ID}"
|
||||
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
f"&state={state}"
|
||||
"&response_type=code"
|
||||
"&prompt=consent"
|
||||
)
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = ConfluenceCloudOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> OAuthSession:
|
||||
session = ConfluenceCloudOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
def generate_finalize_url(cls, credential_id: int) -> str:
|
||||
return f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/finalize?credential={credential_id}"
|
||||
|
||||
|
||||
@router.post("/connector/confluence/callback")
|
||||
def confluence_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""Handles the backend logic for the frontend page that the user is redirected to
|
||||
after visiting the oauth authorization url."""
|
||||
|
||||
if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Confluence Cloud client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
-len(state) % 4
|
||||
) # Add padding back (Base64 decoding requires padding)
|
||||
uuid_bytes = base64.urlsafe_b64decode(
|
||||
padded_state
|
||||
) # Decode the Base64 string back to bytes
|
||||
|
||||
# Convert bytes back to a UUID
|
||||
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
session_json_bytes = cast(bytes, r.get(r_key))
|
||||
if not session_json_bytes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Confluence Cloud OAuth failed - OAuth state key not found: key={r_key}",
|
||||
)
|
||||
|
||||
session_json = session_json_bytes.decode("utf-8")
|
||||
try:
|
||||
session = ConfluenceCloudOAuth.parse_session(session_json)
|
||||
|
||||
if not DEV_MODE:
|
||||
redirect_uri = ConfluenceCloudOAuth.REDIRECT_URI
|
||||
else:
|
||||
redirect_uri = ConfluenceCloudOAuth.DEV_REDIRECT_URI
|
||||
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.post(
|
||||
ConfluenceCloudOAuth.TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"client_id": ConfluenceCloudOAuth.CLIENT_ID,
|
||||
"client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
)
|
||||
|
||||
token_response: ConfluenceCloudOAuth.TokenResponse | None = None
|
||||
|
||||
try:
|
||||
token_response = ConfluenceCloudOAuth.TokenResponse.model_validate_json(
|
||||
response.text
|
||||
)
|
||||
except Exception:
|
||||
raise RuntimeError(
|
||||
"Confluence Cloud OAuth failed during code/token exchange."
|
||||
)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + timedelta(seconds=token_response.expires_in)
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json={
|
||||
"confluence_access_token": token_response.access_token,
|
||||
"confluence_refresh_token": token_response.refresh_token,
|
||||
"created_at": now.isoformat(),
|
||||
"expires_at": expires_at.isoformat(),
|
||||
"expires_in": token_response.expires_in,
|
||||
"scope": token_response.scope,
|
||||
},
|
||||
admin_public=True,
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
name="Confluence Cloud OAuth",
|
||||
)
|
||||
|
||||
credential = create_credential(credential_info, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Confluence Cloud OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
r.delete(r_key)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Confluence Cloud OAuth completed successfully.",
|
||||
"finalize_url": ConfluenceCloudOAuth.generate_finalize_url(credential.id),
|
||||
"redirect_on_success": session.redirect_on_success,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/connector/confluence/accessible-resources")
|
||||
def confluence_oauth_accessible_resources(
|
||||
credential_id: int,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""Atlassian's API is weird and does not supply us with enough info to be in a
|
||||
usable state after authorizing. All API's require a cloud id. We have to list
|
||||
the accessible resources/sites and let the user choose which site to use."""
|
||||
|
||||
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
|
||||
if not credential:
|
||||
raise HTTPException(400, f"Credential {credential_id} not found.")
|
||||
|
||||
credential_dict = credential.credential_json
|
||||
access_token = credential_dict["confluence_access_token"]
|
||||
|
||||
try:
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.get(
|
||||
ConfluenceCloudOAuth.ACCESSIBLE_RESOURCE_URL,
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
accessible_resources_data = response.json()
|
||||
|
||||
# Validate the list of AccessibleResources
|
||||
try:
|
||||
accessible_resources = [
|
||||
ConfluenceCloudOAuth.AccessibleResources(**resource)
|
||||
for resource in accessible_resources_data
|
||||
]
|
||||
except ValidationError as e:
|
||||
raise RuntimeError(f"Failed to parse accessible resources: {e}")
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred retrieving Confluence Cloud accessible resources: {str(e)}",
|
||||
},
|
||||
)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Confluence Cloud get accessible resources completed successfully.",
|
||||
"accessible_resources": [
|
||||
resource.model_dump() for resource in accessible_resources
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/connector/confluence/finalize")
|
||||
def confluence_oauth_finalize(
|
||||
credential_id: int,
|
||||
cloud_id: str,
|
||||
cloud_name: str,
|
||||
cloud_url: str,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""Saves the info for the selected cloud site to the credential.
|
||||
This is the final step in the confluence oauth flow where after the traditional
|
||||
OAuth process, the user has to select a site to associate with the credentials.
|
||||
After this, the credential is usable."""
|
||||
|
||||
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
|
||||
if not credential:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Confluence Cloud OAuth failed - credential {credential_id} not found.",
|
||||
)
|
||||
|
||||
new_credential_json: dict[str, Any] = dict(credential.credential_json)
|
||||
new_credential_json["cloud_id"] = cloud_id
|
||||
new_credential_json["cloud_name"] = cloud_name
|
||||
new_credential_json["wiki_base"] = cloud_url
|
||||
|
||||
try:
|
||||
update_credential_json(credential_id, new_credential_json, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Confluence Cloud OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Confluence Cloud OAuth finalized successfully.",
|
||||
"redirect_url": f"{WEB_DOMAIN}/admin/connectors/confluence",
|
||||
}
|
||||
)
|
229
backend/ee/onyx/server/oauth/google_drive.py
Normal file
229
backend/ee/onyx/server/oauth/google_drive.py
Normal file
@ -0,0 +1,229 @@
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
|
||||
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
|
||||
|
||||
class GoogleDriveOAuth:
|
||||
# https://developers.google.com/identity/protocols/oauth2
|
||||
# https://developers.google.com/identity/protocols/oauth2/web-server
|
||||
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
|
||||
TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
|
||||
# SCOPE is per https://docs.danswer.dev/connectors/google-drive
|
||||
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
|
||||
SCOPE = (
|
||||
"https://www.googleapis.com/auth/drive.readonly%20"
|
||||
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
|
||||
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
|
||||
"https://www.googleapis.com/auth/admin.directory.group.readonly"
|
||||
)
|
||||
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
# without prompt=consent, a refresh token is only issued the first time the user approves
|
||||
url = (
|
||||
f"https://accounts.google.com/o/oauth2/v2/auth"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
"&response_type=code"
|
||||
f"&scope={cls.SCOPE}"
|
||||
"&access_type=offline"
|
||||
f"&state={state}"
|
||||
"&prompt=consent"
|
||||
)
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = GoogleDriveOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> OAuthSession:
|
||||
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
|
||||
@router.post("/connector/google-drive/callback")
|
||||
def handle_google_drive_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Google Drive client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
-len(state) % 4
|
||||
) # Add padding back (Base64 decoding requires padding)
|
||||
uuid_bytes = base64.urlsafe_b64decode(
|
||||
padded_state
|
||||
) # Decode the Base64 string back to bytes
|
||||
|
||||
# Convert bytes back to a UUID
|
||||
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
session_json_bytes = cast(bytes, r.get(r_key))
|
||||
if not session_json_bytes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
|
||||
)
|
||||
|
||||
session_json = session_json_bytes.decode("utf-8")
|
||||
try:
|
||||
session = GoogleDriveOAuth.parse_session(session_json)
|
||||
|
||||
if not DEV_MODE:
|
||||
redirect_uri = GoogleDriveOAuth.REDIRECT_URI
|
||||
else:
|
||||
redirect_uri = GoogleDriveOAuth.DEV_REDIRECT_URI
|
||||
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.post(
|
||||
GoogleDriveOAuth.TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"client_id": GoogleDriveOAuth.CLIENT_ID,
|
||||
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
authorization_response: dict[str, Any] = response.json()
|
||||
|
||||
# the connector wants us to store the json in its authorized_user_info format
|
||||
# returned from OAuthCredentials.get_authorized_user_info().
|
||||
# So refresh immediately via get_google_oauth_creds with the params filled in
|
||||
# from fields in authorization_response to get the json we need
|
||||
authorized_user_info = {}
|
||||
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
|
||||
|
||||
token_json_str = json.dumps(authorized_user_info)
|
||||
oauth_creds = get_google_oauth_creds(
|
||||
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
|
||||
)
|
||||
if not oauth_creds:
|
||||
raise RuntimeError("get_google_oauth_creds returned None.")
|
||||
|
||||
# save off the credentials
|
||||
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
|
||||
|
||||
credential_dict: dict[str, str] = {}
|
||||
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
|
||||
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
|
||||
credential_dict[
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD
|
||||
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
name="OAuth (interactive)",
|
||||
)
|
||||
|
||||
create_credential(credential_info, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
r.delete(r_key)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Google Drive OAuth completed successfully.",
|
||||
"finalize_url": None,
|
||||
"redirect_on_success": session.redirect_on_success,
|
||||
}
|
||||
)
|
197
backend/ee/onyx/server/oauth/slack.py
Normal file
197
backend/ee/onyx/server/oauth/slack.py
Normal file
@ -0,0 +1,197 @@
|
||||
import base64
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
|
||||
|
||||
class SlackOAuth:
|
||||
# https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
|
||||
# Example: https://api.slack.com/authentication/oauth-v2#exchanging
|
||||
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
CLIENT_ID = OAUTH_SLACK_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET
|
||||
|
||||
TOKEN_URL = "https://slack.com/api/oauth.v2.access"
|
||||
|
||||
# SCOPE is per https://docs.danswer.dev/connectors/slack
|
||||
BOT_SCOPE = (
|
||||
"channels:history,"
|
||||
"channels:read,"
|
||||
"groups:history,"
|
||||
"groups:read,"
|
||||
"channels:join,"
|
||||
"im:history,"
|
||||
"users:read,"
|
||||
"users:read.email,"
|
||||
"usergroups:read"
|
||||
)
|
||||
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
url = (
|
||||
f"https://slack.com/oauth/v2/authorize"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
f"&scope={cls.BOT_SCOPE}"
|
||||
f"&state={state}"
|
||||
)
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = SlackOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> OAuthSession:
|
||||
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
|
||||
@router.post("/connector/slack/callback")
|
||||
def handle_slack_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Slack client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
-len(state) % 4
|
||||
) # Add padding back (Base64 decoding requires padding)
|
||||
uuid_bytes = base64.urlsafe_b64decode(
|
||||
padded_state
|
||||
) # Decode the Base64 string back to bytes
|
||||
|
||||
# Convert bytes back to a UUID
|
||||
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
session_json_bytes = cast(bytes, r.get(r_key))
|
||||
if not session_json_bytes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
|
||||
)
|
||||
|
||||
session_json = session_json_bytes.decode("utf-8")
|
||||
try:
|
||||
session = SlackOAuth.parse_session(session_json)
|
||||
|
||||
if not DEV_MODE:
|
||||
redirect_uri = SlackOAuth.REDIRECT_URI
|
||||
else:
|
||||
redirect_uri = SlackOAuth.DEV_REDIRECT_URI
|
||||
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.post(
|
||||
SlackOAuth.TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"client_id": SlackOAuth.CLIENT_ID,
|
||||
"client_secret": SlackOAuth.CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
)
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
if not response_data.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Slack OAuth failed: {response_data.get('error')}",
|
||||
)
|
||||
|
||||
# Extract token and team information
|
||||
access_token: str = response_data.get("access_token")
|
||||
team_id: str = response_data.get("team", {}).get("id")
|
||||
authed_user_id: str = response_data.get("authed_user", {}).get("id")
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json={"slack_bot_token": access_token},
|
||||
admin_public=True,
|
||||
source=DocumentSource.SLACK,
|
||||
name="Slack OAuth",
|
||||
)
|
||||
|
||||
create_credential(credential_info, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Slack OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
r.delete(r_key)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Slack OAuth completed successfully.",
|
||||
"finalize_url": None,
|
||||
"redirect_on_success": session.redirect_on_success,
|
||||
"team_id": team_id,
|
||||
"authed_user_id": authed_user_id,
|
||||
}
|
||||
)
|
@ -423,7 +423,7 @@ def connector_external_group_sync_generator_task(
|
||||
)
|
||||
external_user_groups: list[ExternalUserGroup] = []
|
||||
try:
|
||||
external_user_groups = ext_group_sync_func(cc_pair)
|
||||
external_user_groups = ext_group_sync_func(tenant_id, cc_pair)
|
||||
except ConnectorValidationError as e:
|
||||
msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
|
||||
update_connector_credential_pair(
|
||||
|
@ -93,10 +93,11 @@ def _get_connector_runner(
|
||||
runnable_connector.validate_connector_settings()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
|
||||
logger.exception("Unable to instantiate connector.")
|
||||
# since we failed to even instantiate the connector, we pause the CCPair since
|
||||
# it will never succeed. Sometimes there are cases where the connector will
|
||||
# it will never succeed
|
||||
|
||||
# Sometimes there are cases where the connector will
|
||||
# intermittently fail to initialize in which case we should pass in
|
||||
# leave_connector_active=True to allow it to continue.
|
||||
# For example, if there is nightly maintenance on a Confluence Server instance,
|
||||
|
@ -11,17 +11,20 @@ from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
|
||||
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
|
||||
from onyx.connectors.confluence.onyx_confluence import attachment_to_content
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
extract_text_from_confluence_html,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import attachment_to_content
|
||||
from onyx.connectors.confluence.utils import build_confluence_document_id
|
||||
from onyx.connectors.confluence.utils import datetime_from_string
|
||||
from onyx.connectors.confluence.utils import extract_text_from_confluence_html
|
||||
from onyx.connectors.confluence.utils import validate_attachment_filetype
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedError
|
||||
from onyx.connectors.interfaces import CredentialsConnector
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
@ -83,7 +86,9 @@ _FULL_EXTENSION_FILTER_STRING = "".join(
|
||||
)
|
||||
|
||||
|
||||
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class ConfluenceConnector(
|
||||
LoadConnector, PollConnector, SlimConnector, CredentialsConnector
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
wiki_base: str,
|
||||
@ -102,7 +107,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self._confluence_client: OnyxConfluence | None = None
|
||||
self.is_cloud = is_cloud
|
||||
|
||||
# Remove trailing slash from wiki_base if present
|
||||
@ -137,6 +141,19 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
|
||||
|
||||
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
|
||||
self.credentials_provider: CredentialsProviderInterface | None = None
|
||||
|
||||
self.probe_kwargs = {
|
||||
"max_backoff_retries": 6,
|
||||
"max_backoff_seconds": 10,
|
||||
}
|
||||
|
||||
self.final_kwargs = {
|
||||
"max_backoff_retries": 10,
|
||||
"max_backoff_seconds": 60,
|
||||
}
|
||||
|
||||
self._confluence_client: OnyxConfluence | None = None
|
||||
|
||||
@property
|
||||
def confluence_client(self) -> OnyxConfluence:
|
||||
@ -144,15 +161,22 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
return self._confluence_client
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
|
||||
# for a list of other hidden constructor args
|
||||
self._confluence_client = build_confluence_client(
|
||||
credentials=credentials,
|
||||
is_cloud=self.is_cloud,
|
||||
wiki_base=self.wiki_base,
|
||||
def set_credentials_provider(
|
||||
self, credentials_provider: CredentialsProviderInterface
|
||||
) -> None:
|
||||
self.credentials_provider = credentials_provider
|
||||
|
||||
# raises exception if there's a problem
|
||||
confluence_client = OnyxConfluence(
|
||||
self.is_cloud, self.wiki_base, credentials_provider
|
||||
)
|
||||
return None
|
||||
confluence_client._probe_connection(**self.probe_kwargs)
|
||||
confluence_client._initialize_connection(**self.final_kwargs)
|
||||
|
||||
self._confluence_client = confluence_client
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
raise NotImplementedError("Use set_credentials_provider with this connector.")
|
||||
|
||||
def _construct_page_query(
|
||||
self,
|
||||
@ -202,12 +226,17 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
return comment_string
|
||||
|
||||
def _convert_object_to_document(
|
||||
self, confluence_object: dict[str, Any]
|
||||
self,
|
||||
confluence_object: dict[str, Any],
|
||||
parent_content_id: str | None = None,
|
||||
) -> Document | None:
|
||||
"""
|
||||
Takes in a confluence object, extracts all metadata, and converts it into a document.
|
||||
If its a page, it extracts the text, adds the comments for the document text.
|
||||
If its an attachment, it just downloads the attachment and converts that into a document.
|
||||
|
||||
parent_content_id: if the object is an attachment, specifies the content id that
|
||||
the attachment is attached to
|
||||
"""
|
||||
# The url and the id are the same
|
||||
object_url = build_confluence_document_id(
|
||||
@ -226,7 +255,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
object_text += self._get_comment_string_for_page_id(confluence_object["id"])
|
||||
elif confluence_object["type"] == "attachment":
|
||||
object_text = attachment_to_content(
|
||||
confluence_client=self.confluence_client, attachment=confluence_object
|
||||
confluence_client=self.confluence_client,
|
||||
attachment=confluence_object,
|
||||
parent_content_id=parent_content_id,
|
||||
)
|
||||
|
||||
if object_text is None:
|
||||
@ -302,7 +333,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
cql=attachment_query,
|
||||
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
|
||||
):
|
||||
doc = self._convert_object_to_document(attachment)
|
||||
doc = self._convert_object_to_document(attachment, confluence_page_id)
|
||||
if doc is not None:
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
|
@ -1,19 +1,37 @@
|
||||
import math
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
from urllib.parse import quote
|
||||
|
||||
import bs4
|
||||
from atlassian import Confluence # type:ignore
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
from requests import HTTPError
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.connectors.confluence.utils import _handle_http_error
|
||||
from onyx.connectors.confluence.utils import confluence_refresh_tokens
|
||||
from onyx.connectors.confluence.utils import get_start_param_from_url
|
||||
from onyx.connectors.confluence.utils import update_param_in_path
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.confluence.utils import validate_attachment_filetype
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@ -22,12 +40,14 @@ logger = setup_logger()
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
|
||||
|
||||
# https://jira.atlassian.com/browse/CONFCLOUD-76433
|
||||
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
|
||||
_REPLACEMENT_EXPANSIONS = "body.view.value"
|
||||
|
||||
_USER_NOT_FOUND = "Unknown Confluence User"
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
|
||||
_USER_EMAIL_CACHE: dict[str, str | None] = {}
|
||||
|
||||
|
||||
class ConfluenceRateLimitError(Exception):
|
||||
pass
|
||||
@ -43,124 +63,349 @@ class ConfluenceUser(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
def _handle_http_error(e: HTTPError, attempt: int) -> int:
|
||||
MIN_DELAY = 2
|
||||
MAX_DELAY = 60
|
||||
STARTING_DELAY = 5
|
||||
BACKOFF = 2
|
||||
|
||||
# Check if the response or headers are None to avoid potential AttributeError
|
||||
if e.response is None or e.response.headers is None:
|
||||
logger.warning("HTTPError with `None` as response or as headers")
|
||||
raise e
|
||||
|
||||
if (
|
||||
e.response.status_code != 429
|
||||
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
|
||||
):
|
||||
raise e
|
||||
|
||||
retry_after = None
|
||||
|
||||
retry_after_header = e.response.headers.get("Retry-After")
|
||||
if retry_after_header is not None:
|
||||
try:
|
||||
retry_after = int(retry_after_header)
|
||||
if retry_after > MAX_DELAY:
|
||||
logger.warning(
|
||||
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
|
||||
)
|
||||
retry_after = MAX_DELAY
|
||||
if retry_after < MIN_DELAY:
|
||||
retry_after = MIN_DELAY
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if retry_after is not None:
|
||||
logger.warning(
|
||||
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
|
||||
)
|
||||
delay = retry_after
|
||||
else:
|
||||
logger.warning(
|
||||
"Rate limiting without retry header. Retrying with exponential backoff..."
|
||||
)
|
||||
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
|
||||
|
||||
delay_until = math.ceil(time.monotonic() + delay)
|
||||
return delay_until
|
||||
|
||||
|
||||
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
|
||||
# this uses the native rate limiting option provided by the
|
||||
# confluence client and otherwise applies a simpler set of error handling
|
||||
def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
MAX_RETRIES = 5
|
||||
|
||||
TIMEOUT = 600
|
||||
timeout_at = time.monotonic() + TIMEOUT
|
||||
|
||||
for attempt in range(MAX_RETRIES):
|
||||
if time.monotonic() > timeout_at:
|
||||
raise TimeoutError(
|
||||
f"Confluence call attempts took longer than {TIMEOUT} seconds."
|
||||
)
|
||||
|
||||
try:
|
||||
# we're relying more on the client to rate limit itself
|
||||
# and applying our own retries in a more specific set of circumstances
|
||||
return confluence_call(*args, **kwargs)
|
||||
except HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
logger.warning(
|
||||
f"HTTPError in confluence call. "
|
||||
f"Retrying in {delay_until} seconds..."
|
||||
)
|
||||
while time.monotonic() < delay_until:
|
||||
# in the future, check a signal here to exit
|
||||
time.sleep(1)
|
||||
except AttributeError as e:
|
||||
# Some error within the Confluence library, unclear why it fails.
|
||||
# Users reported it to be intermittent, so just retry
|
||||
if attempt == MAX_RETRIES - 1:
|
||||
raise e
|
||||
|
||||
logger.exception(
|
||||
"Confluence Client raised an AttributeError. Retrying..."
|
||||
)
|
||||
time.sleep(5)
|
||||
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
|
||||
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||
_MINIMUM_PAGINATION_LIMIT = 50
|
||||
|
||||
|
||||
class OnyxConfluence(Confluence):
|
||||
class OnyxConfluence:
|
||||
"""
|
||||
This is a custom Confluence class that overrides the default Confluence class to add a custom CQL method.
|
||||
This is a custom Confluence class that:
|
||||
|
||||
A. overrides the default Confluence class to add a custom CQL method.
|
||||
B.
|
||||
This is necessary because the default Confluence class does not properly support cql expansions.
|
||||
All methods are automatically wrapped with handle_confluence_rate_limit.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, *args: Any, **kwargs: Any) -> None:
|
||||
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
|
||||
self._wrap_methods()
|
||||
CREDENTIAL_PREFIX = "connector:confluence:credential"
|
||||
CREDENTIAL_TTL = 300 # 5 min
|
||||
|
||||
def _wrap_methods(self) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
is_cloud: bool,
|
||||
url: str,
|
||||
credentials_provider: CredentialsProviderInterface,
|
||||
) -> None:
|
||||
self._is_cloud = is_cloud
|
||||
self._url = url.rstrip("/")
|
||||
self._credentials_provider = credentials_provider
|
||||
|
||||
self.redis_client: Redis | None = None
|
||||
self.static_credentials: dict[str, Any] | None = None
|
||||
if self._credentials_provider.is_dynamic():
|
||||
self.redis_client = get_redis_client(
|
||||
tenant_id=credentials_provider.get_tenant_id()
|
||||
)
|
||||
else:
|
||||
self.static_credentials = self._credentials_provider.get_credentials()
|
||||
|
||||
self._confluence = Confluence(url)
|
||||
self.credential_key: str = (
|
||||
self.CREDENTIAL_PREFIX
|
||||
+ f":credential_{self._credentials_provider.get_provider_key()}"
|
||||
)
|
||||
|
||||
self._kwargs: Any = None
|
||||
|
||||
self.shared_base_kwargs = {
|
||||
"api_version": "cloud" if is_cloud else "latest",
|
||||
"backoff_and_retry": True,
|
||||
"cloud": is_cloud,
|
||||
}
|
||||
|
||||
def _renew_credentials(self) -> tuple[dict[str, Any], bool]:
|
||||
"""credential_json - the current json credentials
|
||||
Returns a tuple
|
||||
1. The up to date credentials
|
||||
2. True if the credentials were updated
|
||||
|
||||
This method is intended to be used within a distributed lock.
|
||||
Lock, call this, update credentials if the tokens were refreshed, then release
|
||||
"""
|
||||
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
|
||||
wrap it with handle_confluence_rate_limit.
|
||||
"""
|
||||
for attr_name in dir(self):
|
||||
if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
|
||||
setattr(
|
||||
self,
|
||||
attr_name,
|
||||
handle_confluence_rate_limit(getattr(self, attr_name)),
|
||||
# static credentials are preloaded, so no locking/redis required
|
||||
if self.static_credentials:
|
||||
return self.static_credentials, False
|
||||
|
||||
if not self.redis_client:
|
||||
raise RuntimeError("self.redis_client is None")
|
||||
|
||||
# dynamic credentials need locking
|
||||
# check redis first, then fallback to the DB
|
||||
credential_raw = self.redis_client.get(self.credential_key)
|
||||
if credential_raw is not None:
|
||||
credential_bytes = cast(bytes, credential_raw)
|
||||
credential_str = credential_bytes.decode("utf-8")
|
||||
credential_json: dict[str, Any] = json.loads(credential_str)
|
||||
else:
|
||||
credential_json = self._credentials_provider.get_credentials()
|
||||
|
||||
if "confluence_refresh_token" not in credential_json:
|
||||
# static credentials ... cache them permanently and return
|
||||
self.static_credentials = credential_json
|
||||
return credential_json, False
|
||||
|
||||
# check if we should refresh tokens. we're deciding to refresh halfway
|
||||
# to expiration
|
||||
now = datetime.now(timezone.utc)
|
||||
created_at = datetime.fromisoformat(credential_json["created_at"])
|
||||
expires_in: int = credential_json["expires_in"]
|
||||
renew_at = created_at + timedelta(seconds=expires_in // 2)
|
||||
if now <= renew_at:
|
||||
# cached/current credentials are reasonably up to date
|
||||
return credential_json, False
|
||||
|
||||
# we need to refresh
|
||||
logger.info("Renewing Confluence Cloud credentials...")
|
||||
new_credentials = confluence_refresh_tokens(
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID,
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET,
|
||||
credential_json["cloud_id"],
|
||||
credential_json["confluence_refresh_token"],
|
||||
)
|
||||
|
||||
# store the new credentials to redis and to the db thru the provider
|
||||
# redis: we use a 5 min TTL because we are given a 10 minute grace period
|
||||
# when keys are rotated. it's easier to expire the cached credentials
|
||||
# reasonably frequently rather than trying to handle strong synchronization
|
||||
# between the db and redis everywhere the credentials might be updated
|
||||
new_credential_str = json.dumps(new_credentials)
|
||||
self.redis_client.set(
|
||||
self.credential_key, new_credential_str, nx=True, ex=self.CREDENTIAL_TTL
|
||||
)
|
||||
self._credentials_provider.set_credentials(new_credentials)
|
||||
|
||||
return new_credentials, True
|
||||
|
||||
@staticmethod
|
||||
def _make_oauth2_dict(credentials: dict[str, Any]) -> dict[str, Any]:
|
||||
oauth2_dict: dict[str, Any] = {}
|
||||
if "confluence_refresh_token" in credentials:
|
||||
oauth2_dict["client_id"] = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
oauth2_dict["token"] = {}
|
||||
oauth2_dict["token"]["access_token"] = credentials[
|
||||
"confluence_access_token"
|
||||
]
|
||||
return oauth2_dict
|
||||
|
||||
def _probe_connection(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
merged_kwargs = {**self.shared_base_kwargs, **kwargs}
|
||||
|
||||
with self._credentials_provider:
|
||||
credentials, _ = self._renew_credentials()
|
||||
|
||||
# probe connection with direct client, no retries
|
||||
if "confluence_refresh_token" in credentials:
|
||||
logger.info("Probing Confluence with OAuth Access Token.")
|
||||
|
||||
oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(
|
||||
credentials
|
||||
)
|
||||
url = (
|
||||
f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
|
||||
)
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
url=url, oauth2=oauth2_dict, **merged_kwargs
|
||||
)
|
||||
else:
|
||||
logger.info("Probing Confluence with Personal Access Token.")
|
||||
url = self._url
|
||||
if self._is_cloud:
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
url=url,
|
||||
username=credentials["confluence_username"],
|
||||
password=credentials["confluence_access_token"],
|
||||
**merged_kwargs,
|
||||
)
|
||||
else:
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
url=url,
|
||||
token=credentials["confluence_access_token"],
|
||||
**merged_kwargs,
|
||||
)
|
||||
|
||||
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
|
||||
|
||||
# uncomment the following for testing
|
||||
# the following is an attempt to retrieve the user's timezone
|
||||
# Unfornately, all data is returned in UTC regardless of the user's time zone
|
||||
# even tho CQL parses incoming times based on the user's time zone
|
||||
# space_key = spaces["results"][0]["key"]
|
||||
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
|
||||
|
||||
if not spaces:
|
||||
raise RuntimeError(
|
||||
f"No spaces found at {url}! "
|
||||
"Check your credentials and wiki_base and make sure "
|
||||
"is_cloud is set correctly."
|
||||
)
|
||||
|
||||
logger.info("Confluence probe succeeded.")
|
||||
|
||||
def _initialize_connection(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Called externally to init the connection in a thread safe manner."""
|
||||
merged_kwargs = {**self.shared_base_kwargs, **kwargs}
|
||||
with self._credentials_provider:
|
||||
credentials, _ = self._renew_credentials()
|
||||
self._confluence = self._initialize_connection_helper(
|
||||
credentials, **merged_kwargs
|
||||
)
|
||||
self._kwargs = merged_kwargs
|
||||
|
||||
def _initialize_connection_helper(
|
||||
self,
|
||||
credentials: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> Confluence:
|
||||
"""Called internally to init the connection. Distributed locking
|
||||
to prevent multiple threads from modifying the credentials
|
||||
must be handled around this function."""
|
||||
|
||||
confluence = None
|
||||
|
||||
# probe connection with direct client, no retries
|
||||
if "confluence_refresh_token" in credentials:
|
||||
logger.info("Connecting to Confluence Cloud with OAuth Access Token.")
|
||||
|
||||
oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(credentials)
|
||||
url = f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
|
||||
confluence = Confluence(url=url, oauth2=oauth2_dict, **kwargs)
|
||||
else:
|
||||
logger.info("Connecting to Confluence with Personal Access Token.")
|
||||
if self._is_cloud:
|
||||
confluence = Confluence(
|
||||
url=self._url,
|
||||
username=credentials["confluence_username"],
|
||||
password=credentials["confluence_access_token"],
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
confluence = Confluence(
|
||||
url=self._url,
|
||||
token=credentials["confluence_access_token"],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return confluence
|
||||
|
||||
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
|
||||
# this uses the native rate limiting option provided by the
|
||||
# confluence client and otherwise applies a simpler set of error handling
|
||||
def _make_rate_limited_confluence_method(
|
||||
self, name: str, credential_provider: CredentialsProviderInterface | None
|
||||
) -> Callable[..., Any]:
|
||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
MAX_RETRIES = 5
|
||||
|
||||
TIMEOUT = 600
|
||||
timeout_at = time.monotonic() + TIMEOUT
|
||||
|
||||
for attempt in range(MAX_RETRIES):
|
||||
if time.monotonic() > timeout_at:
|
||||
raise TimeoutError(
|
||||
f"Confluence call attempts took longer than {TIMEOUT} seconds."
|
||||
)
|
||||
|
||||
# we're relying more on the client to rate limit itself
|
||||
# and applying our own retries in a more specific set of circumstances
|
||||
try:
|
||||
if credential_provider:
|
||||
with credential_provider:
|
||||
credentials, renewed = self._renew_credentials()
|
||||
if renewed:
|
||||
self._confluence = self._initialize_connection_helper(
|
||||
credentials, **self._kwargs
|
||||
)
|
||||
attr = getattr(self._confluence, name, None)
|
||||
if attr is None:
|
||||
# The underlying Confluence client doesn't have this attribute
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute '{name}'"
|
||||
)
|
||||
|
||||
return attr(*args, **kwargs)
|
||||
else:
|
||||
attr = getattr(self._confluence, name, None)
|
||||
if attr is None:
|
||||
# The underlying Confluence client doesn't have this attribute
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute '{name}'"
|
||||
)
|
||||
|
||||
return attr(*args, **kwargs)
|
||||
|
||||
except HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
logger.warning(
|
||||
f"HTTPError in confluence call. "
|
||||
f"Retrying in {delay_until} seconds..."
|
||||
)
|
||||
while time.monotonic() < delay_until:
|
||||
# in the future, check a signal here to exit
|
||||
time.sleep(1)
|
||||
except AttributeError as e:
|
||||
# Some error within the Confluence library, unclear why it fails.
|
||||
# Users reported it to be intermittent, so just retry
|
||||
if attempt == MAX_RETRIES - 1:
|
||||
raise e
|
||||
|
||||
logger.exception(
|
||||
"Confluence Client raised an AttributeError. Retrying..."
|
||||
)
|
||||
time.sleep(5)
|
||||
|
||||
return wrapped_call
|
||||
|
||||
# def _wrap_methods(self) -> None:
|
||||
# """
|
||||
# For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
|
||||
# wrap it with handle_confluence_rate_limit.
|
||||
# """
|
||||
# for attr_name in dir(self):
|
||||
# if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
|
||||
# setattr(
|
||||
# self,
|
||||
# attr_name,
|
||||
# handle_confluence_rate_limit(getattr(self, attr_name)),
|
||||
# )
|
||||
|
||||
# def _ensure_token_valid(self) -> None:
|
||||
# if self._token_is_expired():
|
||||
# self._refresh_token()
|
||||
# # Re-init the Confluence client with the originally stored args
|
||||
# self._confluence = Confluence(self._url, *self._args, **self._kwargs)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Dynamically intercept attribute/method access."""
|
||||
attr = getattr(self._confluence, name, None)
|
||||
if attr is None:
|
||||
# The underlying Confluence client doesn't have this attribute
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute '{name}'"
|
||||
)
|
||||
|
||||
# If it's not a method, just return it after ensuring token validity
|
||||
if not callable(attr):
|
||||
return attr
|
||||
|
||||
# skip methods that start with "_"
|
||||
if name.startswith("_"):
|
||||
return attr
|
||||
|
||||
# wrap the method with our retry handler
|
||||
rate_limited_method: Callable[
|
||||
..., Any
|
||||
] = self._make_rate_limited_confluence_method(name, self._credentials_provider)
|
||||
|
||||
def wrapped_method(*args: Any, **kwargs: Any) -> Any:
|
||||
return rate_limited_method(*args, **kwargs)
|
||||
|
||||
return wrapped_method
|
||||
|
||||
def _paginate_url(
|
||||
self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False
|
||||
@ -507,63 +752,212 @@ class OnyxConfluence(Confluence):
|
||||
return response
|
||||
|
||||
|
||||
def _validate_connector_configuration(
|
||||
credentials: dict[str, Any],
|
||||
is_cloud: bool,
|
||||
wiki_base: str,
|
||||
) -> None:
|
||||
# test connection with direct client, no retries
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
api_version="cloud" if is_cloud else "latest",
|
||||
url=wiki_base.rstrip("/"),
|
||||
username=credentials["confluence_username"] if is_cloud else None,
|
||||
password=credentials["confluence_access_token"] if is_cloud else None,
|
||||
token=credentials["confluence_access_token"] if not is_cloud else None,
|
||||
backoff_and_retry=True,
|
||||
max_backoff_retries=6,
|
||||
max_backoff_seconds=10,
|
||||
def get_user_email_from_username__server(
|
||||
confluence_client: OnyxConfluence, user_name: str
|
||||
) -> str | None:
|
||||
global _USER_EMAIL_CACHE
|
||||
if _USER_EMAIL_CACHE.get(user_name) is None:
|
||||
try:
|
||||
response = confluence_client.get_mobile_parameters(user_name)
|
||||
email = response.get("email")
|
||||
except Exception:
|
||||
logger.warning(f"failed to get confluence email for {user_name}")
|
||||
# For now, we'll just return None and log a warning. This means
|
||||
# we will keep retrying to get the email every group sync.
|
||||
email = None
|
||||
# We may want to just return a string that indicates failure so we dont
|
||||
# keep retrying
|
||||
# email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
|
||||
_USER_EMAIL_CACHE[user_name] = email
|
||||
return _USER_EMAIL_CACHE[user_name]
|
||||
|
||||
|
||||
def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
|
||||
"""Get Confluence Display Name based on the account-id or userkey value
|
||||
|
||||
Args:
|
||||
user_id (str): The user id (i.e: the account-id or userkey)
|
||||
confluence_client (Confluence): The Confluence Client
|
||||
|
||||
Returns:
|
||||
str: The User Display Name. 'Unknown User' if the user is deactivated or not found
|
||||
"""
|
||||
global _USER_ID_TO_DISPLAY_NAME_CACHE
|
||||
if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
|
||||
try:
|
||||
result = confluence_client.get_user_details_by_userkey(user_id)
|
||||
found_display_name = result.get("displayName")
|
||||
except Exception:
|
||||
found_display_name = None
|
||||
|
||||
if not found_display_name:
|
||||
try:
|
||||
result = confluence_client.get_user_details_by_accountid(user_id)
|
||||
found_display_name = result.get("displayName")
|
||||
except Exception:
|
||||
found_display_name = None
|
||||
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name
|
||||
|
||||
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
|
||||
|
||||
|
||||
def attachment_to_content(
|
||||
confluence_client: OnyxConfluence,
|
||||
attachment: dict[str, Any],
|
||||
parent_content_id: str | None = None,
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
if not validate_attachment_filetype(attachment):
|
||||
return None
|
||||
|
||||
if "api.atlassian.com" in confluence_client.url:
|
||||
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
|
||||
if not parent_content_id:
|
||||
logger.warning(
|
||||
"parent_content_id is required to download attachments from Confluence Cloud!"
|
||||
)
|
||||
return None
|
||||
|
||||
download_link = (
|
||||
confluence_client.url
|
||||
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
|
||||
)
|
||||
else:
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to size. "
|
||||
f"size={attachment_size} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
|
||||
|
||||
# why are we using session.get here? we probably won't retry these ... is that ok?
|
||||
response = confluence_client._session.get(download_link)
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
|
||||
)
|
||||
return None
|
||||
|
||||
extracted_text = extract_file_text(
|
||||
io.BytesIO(response.content),
|
||||
file_name=attachment["title"],
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
|
||||
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to char count. "
|
||||
f"char count={len(extracted_text)} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
# uncomment the following for testing
|
||||
# the following is an attempt to retrieve the user's timezone
|
||||
# Unfornately, all data is returned in UTC regardless of the user's time zone
|
||||
# even tho CQL parses incoming times based on the user's time zone
|
||||
# space_key = spaces["results"][0]["key"]
|
||||
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
|
||||
return extracted_text
|
||||
|
||||
if not spaces:
|
||||
raise RuntimeError(
|
||||
f"No spaces found at {wiki_base}! "
|
||||
"Check your credentials and wiki_base and make sure "
|
||||
"is_cloud is set correctly."
|
||||
|
||||
def extract_text_from_confluence_html(
|
||||
confluence_client: OnyxConfluence,
|
||||
confluence_object: dict[str, Any],
|
||||
fetched_titles: set[str],
|
||||
) -> str:
|
||||
"""Parse a Confluence html page and replace the 'user Id' by the real
|
||||
User Display Name
|
||||
|
||||
Args:
|
||||
confluence_object (dict): The confluence object as a dict
|
||||
confluence_client (Confluence): Confluence client
|
||||
fetched_titles (set[str]): The titles of the pages that have already been fetched
|
||||
Returns:
|
||||
str: loaded and formated Confluence page
|
||||
"""
|
||||
body = confluence_object["body"]
|
||||
object_html = body.get("storage", body.get("view", {})).get("value")
|
||||
|
||||
soup = bs4.BeautifulSoup(object_html, "html.parser")
|
||||
for user in soup.findAll("ri:user"):
|
||||
user_id = (
|
||||
user.attrs["ri:account-id"]
|
||||
if "ri:account-id" in user.attrs
|
||||
else user.get("ri:userkey")
|
||||
)
|
||||
if not user_id:
|
||||
logger.warning(
|
||||
"ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}"
|
||||
)
|
||||
continue
|
||||
# Include @ sign for tagging, more clear for LLM
|
||||
user.replaceWith("@" + _get_user(confluence_client, user_id))
|
||||
|
||||
for html_page_reference in soup.findAll("ac:structured-macro"):
|
||||
# Here, we only want to process page within page macros
|
||||
if html_page_reference.attrs.get("ac:name") != "include":
|
||||
continue
|
||||
|
||||
page_data = html_page_reference.find("ri:page")
|
||||
if not page_data:
|
||||
logger.warning(
|
||||
f"Skipping retrieval of {html_page_reference} because because page data is missing"
|
||||
)
|
||||
continue
|
||||
|
||||
page_title = page_data.attrs.get("ri:content-title")
|
||||
if not page_title:
|
||||
# only fetch pages that have a title
|
||||
logger.warning(
|
||||
f"Skipping retrieval of {html_page_reference} because it has no title"
|
||||
)
|
||||
continue
|
||||
|
||||
if page_title in fetched_titles:
|
||||
# prevent recursive fetching of pages
|
||||
logger.debug(f"Skipping {page_title} because it has already been fetched")
|
||||
continue
|
||||
|
||||
fetched_titles.add(page_title)
|
||||
|
||||
# Wrap this in a try-except because there are some pages that might not exist
|
||||
try:
|
||||
page_query = f"type=page and title='{quote(page_title)}'"
|
||||
|
||||
page_contents: dict[str, Any] | None = None
|
||||
# Confluence enforces title uniqueness, so we should only get one result here
|
||||
for page in confluence_client.paginated_cql_retrieval(
|
||||
cql=page_query,
|
||||
expand="body.storage.value",
|
||||
limit=1,
|
||||
):
|
||||
page_contents = page
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error getting page contents for object {confluence_object}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not page_contents:
|
||||
continue
|
||||
|
||||
text_from_page = extract_text_from_confluence_html(
|
||||
confluence_client=confluence_client,
|
||||
confluence_object=page_contents,
|
||||
fetched_titles=fetched_titles,
|
||||
)
|
||||
|
||||
html_page_reference.replaceWith(text_from_page)
|
||||
|
||||
def build_confluence_client(
|
||||
credentials: dict[str, Any],
|
||||
is_cloud: bool,
|
||||
wiki_base: str,
|
||||
) -> OnyxConfluence:
|
||||
try:
|
||||
_validate_connector_configuration(
|
||||
credentials=credentials,
|
||||
is_cloud=is_cloud,
|
||||
wiki_base=wiki_base,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ConnectorValidationError(str(e))
|
||||
for html_link_body in soup.findAll("ac:link-body"):
|
||||
# This extracts the text from inline links in the page so they can be
|
||||
# represented in the document text as plain text
|
||||
try:
|
||||
text_from_link = html_link_body.text
|
||||
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing ac:link-body: {e}")
|
||||
|
||||
return OnyxConfluence(
|
||||
api_version="cloud" if is_cloud else "latest",
|
||||
# Remove trailing slash from wiki_base if present
|
||||
url=wiki_base.rstrip("/"),
|
||||
# passing in username causes issues for Confluence data center
|
||||
username=credentials["confluence_username"] if is_cloud else None,
|
||||
password=credentials["confluence_access_token"] if is_cloud else None,
|
||||
token=credentials["confluence_access_token"] if not is_cloud else None,
|
||||
backoff_and_retry=True,
|
||||
max_backoff_retries=10,
|
||||
max_backoff_seconds=60,
|
||||
cloud=is_cloud,
|
||||
)
|
||||
return format_document_soup(soup)
|
||||
|
@ -1,185 +1,38 @@
|
||||
import io
|
||||
import math
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import bs4
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
pass
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_USER_EMAIL_CACHE: dict[str, str | None] = {}
|
||||
CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
|
||||
|
||||
|
||||
def get_user_email_from_username__server(
|
||||
confluence_client: "OnyxConfluence", user_name: str
|
||||
) -> str | None:
|
||||
global _USER_EMAIL_CACHE
|
||||
if _USER_EMAIL_CACHE.get(user_name) is None:
|
||||
try:
|
||||
response = confluence_client.get_mobile_parameters(user_name)
|
||||
email = response.get("email")
|
||||
except Exception:
|
||||
logger.warning(f"failed to get confluence email for {user_name}")
|
||||
# For now, we'll just return None and log a warning. This means
|
||||
# we will keep retrying to get the email every group sync.
|
||||
email = None
|
||||
# We may want to just return a string that indicates failure so we dont
|
||||
# keep retrying
|
||||
# email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
|
||||
_USER_EMAIL_CACHE[user_name] = email
|
||||
return _USER_EMAIL_CACHE[user_name]
|
||||
|
||||
|
||||
_USER_NOT_FOUND = "Unknown Confluence User"
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
|
||||
|
||||
|
||||
def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
|
||||
"""Get Confluence Display Name based on the account-id or userkey value
|
||||
|
||||
Args:
|
||||
user_id (str): The user id (i.e: the account-id or userkey)
|
||||
confluence_client (Confluence): The Confluence Client
|
||||
|
||||
Returns:
|
||||
str: The User Display Name. 'Unknown User' if the user is deactivated or not found
|
||||
"""
|
||||
global _USER_ID_TO_DISPLAY_NAME_CACHE
|
||||
if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
|
||||
try:
|
||||
result = confluence_client.get_user_details_by_userkey(user_id)
|
||||
found_display_name = result.get("displayName")
|
||||
except Exception:
|
||||
found_display_name = None
|
||||
|
||||
if not found_display_name:
|
||||
try:
|
||||
result = confluence_client.get_user_details_by_accountid(user_id)
|
||||
found_display_name = result.get("displayName")
|
||||
except Exception:
|
||||
found_display_name = None
|
||||
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name
|
||||
|
||||
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
|
||||
|
||||
|
||||
def extract_text_from_confluence_html(
|
||||
confluence_client: "OnyxConfluence",
|
||||
confluence_object: dict[str, Any],
|
||||
fetched_titles: set[str],
|
||||
) -> str:
|
||||
"""Parse a Confluence html page and replace the 'user Id' by the real
|
||||
User Display Name
|
||||
|
||||
Args:
|
||||
confluence_object (dict): The confluence object as a dict
|
||||
confluence_client (Confluence): Confluence client
|
||||
fetched_titles (set[str]): The titles of the pages that have already been fetched
|
||||
Returns:
|
||||
str: loaded and formated Confluence page
|
||||
"""
|
||||
body = confluence_object["body"]
|
||||
object_html = body.get("storage", body.get("view", {})).get("value")
|
||||
|
||||
soup = bs4.BeautifulSoup(object_html, "html.parser")
|
||||
for user in soup.findAll("ri:user"):
|
||||
user_id = (
|
||||
user.attrs["ri:account-id"]
|
||||
if "ri:account-id" in user.attrs
|
||||
else user.get("ri:userkey")
|
||||
)
|
||||
if not user_id:
|
||||
logger.warning(
|
||||
"ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}"
|
||||
)
|
||||
continue
|
||||
# Include @ sign for tagging, more clear for LLM
|
||||
user.replaceWith("@" + _get_user(confluence_client, user_id))
|
||||
|
||||
for html_page_reference in soup.findAll("ac:structured-macro"):
|
||||
# Here, we only want to process page within page macros
|
||||
if html_page_reference.attrs.get("ac:name") != "include":
|
||||
continue
|
||||
|
||||
page_data = html_page_reference.find("ri:page")
|
||||
if not page_data:
|
||||
logger.warning(
|
||||
f"Skipping retrieval of {html_page_reference} because because page data is missing"
|
||||
)
|
||||
continue
|
||||
|
||||
page_title = page_data.attrs.get("ri:content-title")
|
||||
if not page_title:
|
||||
# only fetch pages that have a title
|
||||
logger.warning(
|
||||
f"Skipping retrieval of {html_page_reference} because it has no title"
|
||||
)
|
||||
continue
|
||||
|
||||
if page_title in fetched_titles:
|
||||
# prevent recursive fetching of pages
|
||||
logger.debug(f"Skipping {page_title} because it has already been fetched")
|
||||
continue
|
||||
|
||||
fetched_titles.add(page_title)
|
||||
|
||||
# Wrap this in a try-except because there are some pages that might not exist
|
||||
try:
|
||||
page_query = f"type=page and title='{quote(page_title)}'"
|
||||
|
||||
page_contents: dict[str, Any] | None = None
|
||||
# Confluence enforces title uniqueness, so we should only get one result here
|
||||
for page in confluence_client.paginated_cql_retrieval(
|
||||
cql=page_query,
|
||||
expand="body.storage.value",
|
||||
limit=1,
|
||||
):
|
||||
page_contents = page
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error getting page contents for object {confluence_object}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not page_contents:
|
||||
continue
|
||||
|
||||
text_from_page = extract_text_from_confluence_html(
|
||||
confluence_client=confluence_client,
|
||||
confluence_object=page_contents,
|
||||
fetched_titles=fetched_titles,
|
||||
)
|
||||
|
||||
html_page_reference.replaceWith(text_from_page)
|
||||
|
||||
for html_link_body in soup.findAll("ac:link-body"):
|
||||
# This extracts the text from inline links in the page so they can be
|
||||
# represented in the document text as plain text
|
||||
try:
|
||||
text_from_link = html_link_body.text
|
||||
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing ac:link-body: {e}")
|
||||
|
||||
return format_document_soup(soup)
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
expires_in: int
|
||||
token_type: str
|
||||
refresh_token: str
|
||||
scope: str
|
||||
|
||||
|
||||
def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
|
||||
@ -193,49 +46,6 @@ def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
|
||||
]
|
||||
|
||||
|
||||
def attachment_to_content(
|
||||
confluence_client: "OnyxConfluence",
|
||||
attachment: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
if not validate_attachment_filetype(attachment):
|
||||
return None
|
||||
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to size. "
|
||||
f"size={attachment_size} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
|
||||
response = confluence_client._session.get(download_link)
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
|
||||
)
|
||||
return None
|
||||
|
||||
extracted_text = extract_file_text(
|
||||
io.BytesIO(response.content),
|
||||
file_name=attachment["title"],
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to char count. "
|
||||
f"char count={len(extracted_text)} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
return extracted_text
|
||||
|
||||
|
||||
def build_confluence_document_id(
|
||||
base_url: str, content_url: str, is_cloud: bool
|
||||
) -> str:
|
||||
@ -284,6 +94,137 @@ def datetime_from_string(datetime_string: str) -> datetime:
|
||||
return datetime_object
|
||||
|
||||
|
||||
def confluence_refresh_tokens(
|
||||
client_id: str, client_secret: str, cloud_id: str, refresh_token: str
|
||||
) -> dict[str, Any]:
|
||||
# rotate the refresh and access token
|
||||
# Note that access tokens are only good for an hour in confluence cloud,
|
||||
# so we're going to have problems if the connector runs for longer
|
||||
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/#use-a-refresh-token-to-get-another-access-token-and-refresh-token-pair
|
||||
response = requests.post(
|
||||
CONFLUENCE_OAUTH_TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"refresh_token": refresh_token,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
token_response = TokenResponse.model_validate_json(response.text)
|
||||
except Exception:
|
||||
raise RuntimeError("Confluence Cloud token refresh failed.")
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + timedelta(seconds=token_response.expires_in)
|
||||
|
||||
new_credentials: dict[str, Any] = {}
|
||||
new_credentials["confluence_access_token"] = token_response.access_token
|
||||
new_credentials["confluence_refresh_token"] = token_response.refresh_token
|
||||
new_credentials["created_at"] = now.isoformat()
|
||||
new_credentials["expires_at"] = expires_at.isoformat()
|
||||
new_credentials["expires_in"] = token_response.expires_in
|
||||
new_credentials["scope"] = token_response.scope
|
||||
new_credentials["cloud_id"] = cloud_id
|
||||
return new_credentials
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
|
||||
# this uses the native rate limiting option provided by the
|
||||
# confluence client and otherwise applies a simpler set of error handling
|
||||
def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
MAX_RETRIES = 5
|
||||
|
||||
TIMEOUT = 600
|
||||
timeout_at = time.monotonic() + TIMEOUT
|
||||
|
||||
for attempt in range(MAX_RETRIES):
|
||||
if time.monotonic() > timeout_at:
|
||||
raise TimeoutError(
|
||||
f"Confluence call attempts took longer than {TIMEOUT} seconds."
|
||||
)
|
||||
|
||||
try:
|
||||
# we're relying more on the client to rate limit itself
|
||||
# and applying our own retries in a more specific set of circumstances
|
||||
return confluence_call(*args, **kwargs)
|
||||
except requests.HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
logger.warning(
|
||||
f"HTTPError in confluence call. "
|
||||
f"Retrying in {delay_until} seconds..."
|
||||
)
|
||||
while time.monotonic() < delay_until:
|
||||
# in the future, check a signal here to exit
|
||||
time.sleep(1)
|
||||
except AttributeError as e:
|
||||
# Some error within the Confluence library, unclear why it fails.
|
||||
# Users reported it to be intermittent, so just retry
|
||||
if attempt == MAX_RETRIES - 1:
|
||||
raise e
|
||||
|
||||
logger.exception(
|
||||
"Confluence Client raised an AttributeError. Retrying..."
|
||||
)
|
||||
time.sleep(5)
|
||||
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
|
||||
def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
|
||||
MIN_DELAY = 2
|
||||
MAX_DELAY = 60
|
||||
STARTING_DELAY = 5
|
||||
BACKOFF = 2
|
||||
|
||||
# Check if the response or headers are None to avoid potential AttributeError
|
||||
if e.response is None or e.response.headers is None:
|
||||
logger.warning("HTTPError with `None` as response or as headers")
|
||||
raise e
|
||||
|
||||
if (
|
||||
e.response.status_code != 429
|
||||
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
|
||||
):
|
||||
raise e
|
||||
|
||||
retry_after = None
|
||||
|
||||
retry_after_header = e.response.headers.get("Retry-After")
|
||||
if retry_after_header is not None:
|
||||
try:
|
||||
retry_after = int(retry_after_header)
|
||||
if retry_after > MAX_DELAY:
|
||||
logger.warning(
|
||||
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
|
||||
)
|
||||
retry_after = MAX_DELAY
|
||||
if retry_after < MIN_DELAY:
|
||||
retry_after = MIN_DELAY
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if retry_after is not None:
|
||||
logger.warning(
|
||||
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
|
||||
)
|
||||
delay = retry_after
|
||||
else:
|
||||
logger.warning(
|
||||
"Rate limiting without retry header. Retrying with exponential backoff..."
|
||||
)
|
||||
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
|
||||
|
||||
delay_until = math.ceil(time.monotonic() + delay)
|
||||
return delay_until
|
||||
|
||||
|
||||
def get_single_param_from_url(url: str, param: str) -> str | None:
|
||||
"""Get a parameter from a url"""
|
||||
parsed_url = urlparse(url)
|
||||
|
135
backend/onyx/connectors/credentials_provider.py
Normal file
135
backend/onyx/connectors/credentials_provider.py
Normal file
@ -0,0 +1,135 @@
|
||||
import uuid
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import Credential
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
|
||||
|
||||
class OnyxDBCredentialsProvider(
|
||||
CredentialsProviderInterface["OnyxDBCredentialsProvider"]
|
||||
):
|
||||
"""Implementation to allow the connector to callback and update credentials in the db.
|
||||
Required in cases where credentials can rotate while the connector is running.
|
||||
"""
|
||||
|
||||
LOCK_TTL = 900 # TTL of the lock
|
||||
|
||||
def __init__(self, tenant_id: str, connector_name: str, credential_id: int):
|
||||
self._tenant_id = tenant_id
|
||||
self._connector_name = connector_name
|
||||
self._credential_id = credential_id
|
||||
|
||||
self.redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# lock used to prevent overlapping renewal of credentials
|
||||
self.lock_key = f"da_lock:connector:{connector_name}:credential_{credential_id}"
|
||||
self._lock: RedisLock = self.redis_client.lock(self.lock_key, self.LOCK_TTL)
|
||||
|
||||
def __enter__(self) -> "OnyxDBCredentialsProvider":
|
||||
acquired = self._lock.acquire(blocking_timeout=self.LOCK_TTL)
|
||||
if not acquired:
|
||||
raise RuntimeError(f"Could not acquire lock for key: {self.lock_key}")
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
"""Release the lock when exiting the context."""
|
||||
if self._lock and self._lock.owned():
|
||||
self._lock.release()
|
||||
|
||||
def get_tenant_id(self) -> str | None:
|
||||
return self._tenant_id
|
||||
|
||||
def get_provider_key(self) -> str:
|
||||
return str(self._credential_id)
|
||||
|
||||
def get_credentials(self) -> dict[str, Any]:
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
|
||||
credential = db_session.execute(
|
||||
select(Credential).where(Credential.id == self._credential_id)
|
||||
).scalar_one()
|
||||
|
||||
if credential is None:
|
||||
raise ValueError(
|
||||
f"No credential found: credential={self._credential_id}"
|
||||
)
|
||||
|
||||
return credential.credential_json
|
||||
|
||||
def set_credentials(self, credential_json: dict[str, Any]) -> None:
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
|
||||
try:
|
||||
credential = db_session.execute(
|
||||
select(Credential)
|
||||
.where(Credential.id == self._credential_id)
|
||||
.with_for_update()
|
||||
).scalar_one()
|
||||
|
||||
if credential is None:
|
||||
raise ValueError(
|
||||
f"No credential found: credential={self._credential_id}"
|
||||
)
|
||||
|
||||
credential.credential_json = credential_json
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
raise
|
||||
|
||||
def is_dynamic(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class OnyxStaticCredentialsProvider(
|
||||
CredentialsProviderInterface["OnyxStaticCredentialsProvider"]
|
||||
):
|
||||
"""Implementation (a very simple one!) to handle static credentials."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str | None,
|
||||
connector_name: str,
|
||||
credential_json: dict[str, Any],
|
||||
):
|
||||
self._tenant_id = tenant_id
|
||||
self._connector_name = connector_name
|
||||
self._credential_json = credential_json
|
||||
|
||||
self._provider_key = str(uuid.uuid4())
|
||||
|
||||
def __enter__(self) -> "OnyxStaticCredentialsProvider":
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def get_tenant_id(self) -> str | None:
|
||||
return self._tenant_id
|
||||
|
||||
def get_provider_key(self) -> str:
|
||||
return self._provider_key
|
||||
|
||||
def get_credentials(self) -> dict[str, Any]:
|
||||
return self._credential_json
|
||||
|
||||
def set_credentials(self, credential_json: dict[str, Any]) -> None:
|
||||
self._credential_json = credential_json
|
||||
|
||||
def is_dynamic(self) -> bool:
|
||||
return False
|
@ -12,6 +12,7 @@ from onyx.connectors.blob.connector import BlobStorageConnector
|
||||
from onyx.connectors.bookstack.connector import BookstackConnector
|
||||
from onyx.connectors.clickup.connector import ClickupConnector
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.discord.connector import DiscordConnector
|
||||
from onyx.connectors.discourse.connector import DiscourseConnector
|
||||
from onyx.connectors.document360.connector import Document360Connector
|
||||
@ -32,6 +33,7 @@ from onyx.connectors.guru.connector import GuruConnector
|
||||
from onyx.connectors.hubspot.connector import HubSpotConnector
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CredentialsConnector
|
||||
from onyx.connectors.interfaces import EventConnector
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
@ -57,6 +59,7 @@ from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.credentials import backend_update_credential_json
|
||||
from onyx.db.credentials import fetch_credential_by_id
|
||||
from onyx.db.models import Credential
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
class ConnectorMissingException(Exception):
|
||||
@ -167,10 +170,17 @@ def instantiate_connector(
|
||||
connector_class = identify_connector_class(source, input_type)
|
||||
|
||||
connector = connector_class(**connector_specific_config)
|
||||
new_credentials = connector.load_credentials(credential.credential_json)
|
||||
|
||||
if new_credentials is not None:
|
||||
backend_update_credential_json(credential, new_credentials, db_session)
|
||||
if isinstance(connector, CredentialsConnector):
|
||||
provider = OnyxDBCredentialsProvider(
|
||||
get_current_tenant_id(), str(source), credential.id
|
||||
)
|
||||
connector.set_credentials_provider(provider)
|
||||
else:
|
||||
new_credentials = connector.load_credentials(credential.credential_json)
|
||||
|
||||
if new_credentials is not None:
|
||||
backend_update_credential_json(credential, new_credentials, db_session)
|
||||
|
||||
return connector
|
||||
|
||||
|
@ -1,7 +1,10 @@
|
||||
import abc
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -111,6 +114,69 @@ class OAuthConnector(BaseConnector):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
T = TypeVar("T", bound="CredentialsProviderInterface")
|
||||
|
||||
|
||||
class CredentialsProviderInterface(abc.ABC, Generic[T]):
|
||||
@abc.abstractmethod
|
||||
def __enter__(self) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_tenant_id(self) -> str | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_provider_key(self) -> str:
|
||||
"""a unique key that the connector can use to lock around a credential
|
||||
that might be used simultaneously.
|
||||
|
||||
Will typically be the credential id, but can also just be something random
|
||||
in cases when there is nothing to lock (aka static credentials)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_credentials(self) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_credentials(self, credential_json: dict[str, Any]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_dynamic(self) -> bool:
|
||||
"""If dynamic, the credentials may change during usage ... maening the client
|
||||
needs to use the locking features of the credentials provider to operate
|
||||
correctly.
|
||||
|
||||
If static, the client can simply reference the credentials once and use them
|
||||
through the entire indexing run.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CredentialsConnector(BaseConnector):
|
||||
"""Implement this if the connector needs to be able to read and write credentials
|
||||
on the fly. Typically used with shared credentials/tokens that might be renewed
|
||||
at any time."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_credentials_provider(
|
||||
self, credentials_provider: CredentialsProviderInterface
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Event driven
|
||||
class EventConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
|
@ -302,29 +302,29 @@ class WebConnector(LoadConnector):
|
||||
playwright, context = start_playwright()
|
||||
restart_playwright = False
|
||||
while to_visit:
|
||||
current_url = to_visit.pop()
|
||||
if current_url in visited_links:
|
||||
initial_url = to_visit.pop()
|
||||
if initial_url in visited_links:
|
||||
continue
|
||||
visited_links.add(current_url)
|
||||
visited_links.add(initial_url)
|
||||
|
||||
try:
|
||||
protected_url_check(current_url)
|
||||
protected_url_check(initial_url)
|
||||
except Exception as e:
|
||||
last_error = f"Invalid URL {current_url} due to {e}"
|
||||
last_error = f"Invalid URL {initial_url} due to {e}"
|
||||
logger.warning(last_error)
|
||||
continue
|
||||
|
||||
logger.info(f"Visiting {current_url}")
|
||||
logger.info(f"{len(visited_links)}: Visiting {initial_url}")
|
||||
|
||||
try:
|
||||
check_internet_connection(current_url)
|
||||
check_internet_connection(initial_url)
|
||||
if restart_playwright:
|
||||
playwright, context = start_playwright()
|
||||
restart_playwright = False
|
||||
|
||||
if current_url.split(".")[-1] == "pdf":
|
||||
if initial_url.split(".")[-1] == "pdf":
|
||||
# PDF files are not checked for links
|
||||
response = requests.get(current_url)
|
||||
response = requests.get(initial_url)
|
||||
page_text, metadata = read_pdf_file(
|
||||
file=io.BytesIO(response.content)
|
||||
)
|
||||
@ -332,10 +332,10 @@ class WebConnector(LoadConnector):
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=current_url,
|
||||
sections=[Section(link=current_url, text=page_text)],
|
||||
id=initial_url,
|
||||
sections=[Section(link=initial_url, text=page_text)],
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=current_url.split("/")[-1],
|
||||
semantic_identifier=initial_url.split("/")[-1],
|
||||
metadata=metadata,
|
||||
doc_updated_at=_get_datetime_from_last_modified_header(
|
||||
last_modified
|
||||
@ -347,21 +347,25 @@ class WebConnector(LoadConnector):
|
||||
continue
|
||||
|
||||
page = context.new_page()
|
||||
page_response = page.goto(current_url)
|
||||
page_response = page.goto(initial_url)
|
||||
last_modified = (
|
||||
page_response.header_value("Last-Modified")
|
||||
if page_response
|
||||
else None
|
||||
)
|
||||
final_page = page.url
|
||||
if final_page != current_url:
|
||||
logger.info(f"Redirected to {final_page}")
|
||||
protected_url_check(final_page)
|
||||
current_url = final_page
|
||||
if current_url in visited_links:
|
||||
logger.info("Redirected page already indexed")
|
||||
final_url = page.url
|
||||
if final_url != initial_url:
|
||||
protected_url_check(final_url)
|
||||
initial_url = final_url
|
||||
if initial_url in visited_links:
|
||||
logger.info(
|
||||
f"{len(visited_links)}: {initial_url} redirected to {final_url} - already indexed"
|
||||
)
|
||||
continue
|
||||
visited_links.add(current_url)
|
||||
logger.info(
|
||||
f"{len(visited_links)}: {initial_url} redirected to {final_url}"
|
||||
)
|
||||
visited_links.add(initial_url)
|
||||
|
||||
if self.scroll_before_scraping:
|
||||
scroll_attempts = 0
|
||||
@ -379,13 +383,13 @@ class WebConnector(LoadConnector):
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
|
||||
if self.recursive:
|
||||
internal_links = get_internal_links(base_url, current_url, soup)
|
||||
internal_links = get_internal_links(base_url, initial_url, soup)
|
||||
for link in internal_links:
|
||||
if link not in visited_links:
|
||||
to_visit.append(link)
|
||||
|
||||
if page_response and str(page_response.status)[0] in ("4", "5"):
|
||||
last_error = f"Skipped indexing {current_url} due to HTTP {page_response.status} response"
|
||||
last_error = f"Skipped indexing {initial_url} due to HTTP {page_response.status} response"
|
||||
logger.info(last_error)
|
||||
continue
|
||||
|
||||
@ -393,12 +397,12 @@ class WebConnector(LoadConnector):
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=current_url,
|
||||
id=initial_url,
|
||||
sections=[
|
||||
Section(link=current_url, text=parsed_html.cleaned_text)
|
||||
Section(link=initial_url, text=parsed_html.cleaned_text)
|
||||
],
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=parsed_html.title or current_url,
|
||||
semantic_identifier=parsed_html.title or initial_url,
|
||||
metadata={},
|
||||
doc_updated_at=_get_datetime_from_last_modified_header(
|
||||
last_modified
|
||||
@ -410,7 +414,7 @@ class WebConnector(LoadConnector):
|
||||
|
||||
page.close()
|
||||
except Exception as e:
|
||||
last_error = f"Failed to fetch '{current_url}': {e}"
|
||||
last_error = f"Failed to fetch '{initial_url}': {e}"
|
||||
logger.exception(last_error)
|
||||
playwright.stop()
|
||||
restart_playwright = True
|
||||
|
@ -51,7 +51,6 @@ from onyx.server.documents.cc_pair import router as cc_pair_router
|
||||
from onyx.server.documents.connector import router as connector_router
|
||||
from onyx.server.documents.credential import router as credential_router
|
||||
from onyx.server.documents.document import router as document_router
|
||||
from onyx.server.documents.standard_oauth import router as oauth_router
|
||||
from onyx.server.features.document_set.api import router as document_set_router
|
||||
from onyx.server.features.folder.api import router as folder_router
|
||||
from onyx.server.features.input_prompt.api import (
|
||||
@ -323,7 +322,6 @@ def get_application() -> FastAPI:
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, long_term_logs_router)
|
||||
include_router_with_global_prefix_prepended(application, api_key_router)
|
||||
include_router_with_global_prefix_prepended(application, oauth_router)
|
||||
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
# Server logs this during auth setup verification step
|
||||
|
@ -46,13 +46,21 @@ def mask_string(sensitive_str: str) -> str:
|
||||
return "****...**" + sensitive_str[-4:]
|
||||
|
||||
|
||||
MASK_CREDENTIALS_WHITELIST = {
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
"wiki_base",
|
||||
"cloud_name",
|
||||
"cloud_id",
|
||||
}
|
||||
|
||||
|
||||
def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
|
||||
masked_creds = {}
|
||||
for key, val in credential_dict.items():
|
||||
if isinstance(val, str):
|
||||
# we want to pass the authentication_method field through so the frontend
|
||||
# can disambiguate credentials created by different methods
|
||||
if key == DB_CREDENTIALS_AUTHENTICATION_METHOD:
|
||||
if key in MASK_CREDENTIALS_WHITELIST:
|
||||
masked_creds[key] = val
|
||||
else:
|
||||
masked_creds[key] = mask_string(val)
|
||||
@ -63,8 +71,8 @@ def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
|
||||
continue
|
||||
|
||||
raise ValueError(
|
||||
f"Unable to mask credentials of type other than string, cannot process request."
|
||||
f"Recieved type: {type(val)}"
|
||||
f"Unable to mask credentials of type other than string or int, cannot process request."
|
||||
f"Received type: {type(val)}"
|
||||
)
|
||||
|
||||
return masked_creds
|
||||
|
@ -5,7 +5,9 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
|
||||
from onyx.connectors.models import Document
|
||||
|
||||
|
||||
@ -18,12 +20,15 @@ def confluence_connector() -> ConfluenceConnector:
|
||||
page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""),
|
||||
)
|
||||
|
||||
connector.load_credentials(
|
||||
credentials_provider = OnyxStaticCredentialsProvider(
|
||||
None,
|
||||
DocumentSource.CONFLUENCE,
|
||||
{
|
||||
"confluence_username": os.environ["CONFLUENCE_USER_NAME"],
|
||||
"confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
|
||||
}
|
||||
},
|
||||
)
|
||||
connector.set_credentials_provider(credentials_provider)
|
||||
return connector
|
||||
|
||||
|
||||
|
@ -2,7 +2,9 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -11,12 +13,16 @@ def confluence_connector() -> ConfluenceConnector:
|
||||
wiki_base="https://danswerai.atlassian.net",
|
||||
is_cloud=True,
|
||||
)
|
||||
connector.load_credentials(
|
||||
|
||||
credentials_provider = OnyxStaticCredentialsProvider(
|
||||
None,
|
||||
DocumentSource.CONFLUENCE,
|
||||
{
|
||||
"confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
|
||||
"confluence_username": os.environ["CONFLUENCE_USER_NAME"],
|
||||
}
|
||||
"confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
|
||||
},
|
||||
)
|
||||
connector.set_credentials_provider(credentials_provider)
|
||||
return connector
|
||||
|
||||
|
||||
|
@ -3,9 +3,7 @@ from unittest.mock import Mock
|
||||
import pytest
|
||||
from requests import HTTPError
|
||||
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
handle_confluence_rate_limit,
|
||||
)
|
||||
from onyx.connectors.confluence.utils import handle_confluence_rate_limit
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -50,6 +48,8 @@ def mock_confluence_call() -> Mock:
|
||||
# mock_sleep.assert_called_with(int(retry_after))
|
||||
|
||||
|
||||
# NOTE(rkuo): This tests an older version of rate limiting that is being deprecated
|
||||
# and probably should go away soon.
|
||||
def test_non_rate_limit_error(mock_confluence_call: Mock) -> None:
|
||||
mock_confluence_call.side_effect = HTTPError(
|
||||
response=Mock(status_code=500, text="Internal Server Error")
|
||||
|
@ -100,11 +100,6 @@ export default function LabelManagement() {
|
||||
width="w-full max-w-xs"
|
||||
name={`editLabelName_${label.id}`}
|
||||
label="Label Name"
|
||||
value={
|
||||
values.editLabelId === label.id
|
||||
? values.editLabelName
|
||||
: label.name
|
||||
}
|
||||
onChange={(e) => {
|
||||
setFieldValue("editLabelId", label.id);
|
||||
setFieldValue("editLabelName", e.target.value);
|
||||
|
@ -52,7 +52,6 @@ export default function StarterMessagesList({
|
||||
<TextFormField
|
||||
name={`starter_messages.${index}.message`}
|
||||
label=""
|
||||
value={starterMessage.message}
|
||||
onChange={(e) => handleInputChange(index, e.target.value)}
|
||||
className="flex-grow"
|
||||
removeLabel
|
||||
|
@ -193,13 +193,15 @@ export default function AddConnector({
|
||||
// Check if there are no credentials
|
||||
const noCredentials = credentialTemplate == null;
|
||||
|
||||
if (noCredentials && 1 != formStep) {
|
||||
setFormStep(Math.max(1, formStep));
|
||||
}
|
||||
useEffect(() => {
|
||||
if (noCredentials && 1 != formStep) {
|
||||
setFormStep(Math.max(1, formStep));
|
||||
}
|
||||
|
||||
if (!noCredentials && !credentialActivated && formStep != 0) {
|
||||
setFormStep(Math.min(formStep, 0));
|
||||
}
|
||||
if (!noCredentials && !credentialActivated && formStep != 0) {
|
||||
setFormStep(Math.min(formStep, 0));
|
||||
}
|
||||
}, [noCredentials, formStep, setFormStep]);
|
||||
|
||||
const convertStringToDateTime = (indexingStart: string | null) => {
|
||||
return indexingStart ? new Date(indexingStart) : null;
|
||||
|
@ -33,7 +33,7 @@ export default function OAuthCallbackPage() {
|
||||
const connector = pathname?.split("/")[3];
|
||||
|
||||
useEffect(() => {
|
||||
const handleOAuthCallback = async () => {
|
||||
const onFirstLoad = async () => {
|
||||
// Examples
|
||||
// connector (url segment)= "google-drive"
|
||||
// sourceType (for looking up metadata) = "google_drive"
|
||||
@ -85,10 +85,19 @@ export default function OAuthCallbackPage() {
|
||||
}
|
||||
|
||||
setStatusMessage("Success!");
|
||||
setStatusDetails(
|
||||
`Your authorization with ${sourceMetadata.displayName} completed successfully.`
|
||||
);
|
||||
setRedirectUrl(response.redirect_on_success); // Extract the redirect URL
|
||||
|
||||
// set the continuation link
|
||||
if (response.finalize_url) {
|
||||
setRedirectUrl(response.finalize_url);
|
||||
setStatusDetails(
|
||||
`Your authorization with ${sourceMetadata.displayName} completed successfully. Additional steps are required to complete credential setup.`
|
||||
);
|
||||
} else {
|
||||
setRedirectUrl(response.redirect_on_success);
|
||||
setStatusDetails(
|
||||
`Your authorization with ${sourceMetadata.displayName} completed successfully.`
|
||||
);
|
||||
}
|
||||
setIsError(false);
|
||||
} catch (error) {
|
||||
console.error("OAuth error:", error);
|
||||
@ -100,15 +109,15 @@ export default function OAuthCallbackPage() {
|
||||
}
|
||||
};
|
||||
|
||||
handleOAuthCallback();
|
||||
onFirstLoad();
|
||||
}, [code, state, connector]);
|
||||
|
||||
return (
|
||||
<div className="container mx-auto py-8">
|
||||
<div className="mx-auto h-screen flex flex-col">
|
||||
<AdminPageTitle title={pageTitle} icon={<KeyIcon size={32} />} />
|
||||
|
||||
<div className="flex flex-col items-center justify-center min-h-screen">
|
||||
<CardSection className="max-w-md">
|
||||
<div className="flex-1 flex flex-col items-center justify-center">
|
||||
<CardSection className="max-w-md w-[500px] h-[250px] p-8">
|
||||
<h1 className="text-2xl font-bold mb-4">{statusMessage}</h1>
|
||||
<p className="text-text-500">{statusDetails}</p>
|
||||
{redirectUrl && !isError && (
|
||||
|
293
web/src/app/admin/connectors/[connector]/oauth/finalize/page.tsx
Normal file
293
web/src/app/admin/connectors/[connector]/oauth/finalize/page.tsx
Normal file
@ -0,0 +1,293 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect, useState } from "react";
|
||||
import { usePathname, useRouter, useSearchParams } from "next/navigation";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import Title from "@/components/ui/title";
|
||||
import { KeyIcon } from "@/components/icons/icons";
|
||||
import { getSourceMetadata, isValidSource } from "@/lib/sources";
|
||||
import { ConfluenceAccessibleResource, ValidSources } from "@/lib/types";
|
||||
import CardSection from "@/components/admin/CardSection";
|
||||
import {
|
||||
handleOAuthAuthorizationResponse,
|
||||
handleOAuthConfluenceFinalize,
|
||||
handleOAuthPrepareFinalization,
|
||||
} from "@/lib/oauth_utils";
|
||||
import { SelectorFormField } from "@/components/admin/connectors/Field";
|
||||
import { ErrorMessage, Field, Form, Formik, useFormikContext } from "formik";
|
||||
import * as Yup from "yup";
|
||||
|
||||
// Helper component to keep the effect logic clean:
|
||||
function UpdateCloudURLOnCloudIdChange({
|
||||
accessibleResources,
|
||||
}: {
|
||||
accessibleResources: ConfluenceAccessibleResource[];
|
||||
}) {
|
||||
const { values, setValues, setFieldValue } = useFormikContext<{
|
||||
cloud_id: string;
|
||||
cloud_name: string;
|
||||
cloud_url: string;
|
||||
}>();
|
||||
|
||||
useEffect(() => {
|
||||
// Whenever cloud_id changes, find the matching resource and update cloud_url
|
||||
if (values.cloud_id) {
|
||||
const selectedResource = accessibleResources.find(
|
||||
(resource) => resource.id === values.cloud_id
|
||||
);
|
||||
if (selectedResource) {
|
||||
// Update multiple fields together ... somehow setting them in sequence
|
||||
// doesn't work with the validator
|
||||
// it may also be possible to await each setFieldValue call.
|
||||
// https://github.com/jaredpalmer/formik/issues/2266
|
||||
setValues((prevValues) => ({
|
||||
...prevValues,
|
||||
cloud_name: selectedResource.name,
|
||||
cloud_url: selectedResource.url,
|
||||
}));
|
||||
}
|
||||
}
|
||||
}, [values.cloud_id, accessibleResources, setFieldValue]);
|
||||
|
||||
// This component doesn't render anything visible:
|
||||
return null;
|
||||
}
|
||||
|
||||
export default function OAuthFinalizePage() {
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
|
||||
const [statusMessage, setStatusMessage] = useState("Processing...");
|
||||
const [statusDetails, setStatusDetails] = useState(
|
||||
"Please wait while we complete the setup."
|
||||
);
|
||||
const [redirectUrl, setRedirectUrl] = useState<string | null>(null);
|
||||
const [isError, setIsError] = useState(false);
|
||||
const [isSubmitted, setIsSubmitted] = useState(false); // New state
|
||||
const [pageTitle, setPageTitle] = useState(
|
||||
"Finalize Authorization with Third-Party service"
|
||||
);
|
||||
|
||||
const [accessibleResources, setAccessibleResources] = useState<
|
||||
ConfluenceAccessibleResource[]
|
||||
>([]);
|
||||
|
||||
// Extract query parameters
|
||||
const credentialParam = searchParams.get("credential");
|
||||
const credential = credentialParam ? parseInt(credentialParam, 10) : NaN;
|
||||
const pathname = usePathname();
|
||||
const connector = pathname?.split("/")[3];
|
||||
|
||||
useEffect(() => {
|
||||
const onFirstLoad = async () => {
|
||||
// Examples
|
||||
// connector (url segment)= "google-drive"
|
||||
// sourceType (for looking up metadata) = "google_drive"
|
||||
|
||||
if (isNaN(credential)) {
|
||||
setStatusMessage("Improperly formed OAuth finalization request.");
|
||||
setStatusDetails("Invalid or missing credential id.");
|
||||
setIsError(true);
|
||||
return;
|
||||
}
|
||||
|
||||
const sourceType = connector.replaceAll("-", "_");
|
||||
if (!isValidSource(sourceType)) {
|
||||
setStatusMessage(
|
||||
`The specified connector source type ${sourceType} does not exist.`
|
||||
);
|
||||
setStatusDetails(`${sourceType} is not a valid source type.`);
|
||||
setIsError(true);
|
||||
return;
|
||||
}
|
||||
|
||||
const sourceMetadata = getSourceMetadata(sourceType as ValidSources);
|
||||
setPageTitle(`Finalize Authorization with ${sourceMetadata.displayName}`);
|
||||
|
||||
setStatusMessage("Processing...");
|
||||
setStatusDetails(
|
||||
"Please wait while we retrieve a list of your accessible sites."
|
||||
);
|
||||
setIsError(false); // Ensure no error state during loading
|
||||
|
||||
try {
|
||||
const response = await handleOAuthPrepareFinalization(
|
||||
connector,
|
||||
credential
|
||||
);
|
||||
|
||||
if (!response) {
|
||||
throw new Error("Empty response from OAuth server.");
|
||||
}
|
||||
|
||||
setAccessibleResources(response.accessible_resources);
|
||||
|
||||
setStatusMessage("Select a Confluence site");
|
||||
setStatusDetails("");
|
||||
|
||||
setIsError(false);
|
||||
} catch (error) {
|
||||
console.error("OAuth finalization error:", error);
|
||||
setStatusMessage("Oops, something went wrong!");
|
||||
setStatusDetails(
|
||||
"An error occurred during the OAuth finalization process. Please try again."
|
||||
);
|
||||
setIsError(true);
|
||||
}
|
||||
};
|
||||
|
||||
onFirstLoad();
|
||||
}, [credential, connector]);
|
||||
|
||||
useEffect(() => {}, [redirectUrl]);
|
||||
|
||||
return (
|
||||
<div className="mx-auto h-screen flex flex-col">
|
||||
<AdminPageTitle title={pageTitle} icon={<KeyIcon size={32} />} />
|
||||
|
||||
<div className="flex-1 flex flex-col items-center justify-center">
|
||||
<CardSection className="max-w-md w-[500px] h-[250px] p-8">
|
||||
<h1 className="text-2xl font-bold mb-4">{statusMessage}</h1>
|
||||
<p className="text-text-500">{statusDetails}</p>
|
||||
|
||||
<Formik
|
||||
initialValues={{
|
||||
credential_id: credential,
|
||||
cloud_id: "",
|
||||
cloud_name: "",
|
||||
cloud_url: "",
|
||||
}}
|
||||
validationSchema={Yup.object().shape({
|
||||
credential_id: Yup.number().required(
|
||||
"Credential ID is required."
|
||||
),
|
||||
cloud_id: Yup.string().required(
|
||||
"You must select a Confluence site (id not found)."
|
||||
),
|
||||
cloud_name: Yup.string().required(
|
||||
"You must select a Confluence site (name not found)."
|
||||
),
|
||||
cloud_url: Yup.string().required(
|
||||
"You must select a Confluence site (url not found)."
|
||||
),
|
||||
})}
|
||||
validateOnMount
|
||||
onSubmit={async (values, formikHelpers) => {
|
||||
formikHelpers.setSubmitting(true);
|
||||
try {
|
||||
if (!values.cloud_id) {
|
||||
throw new Error("Cloud ID is required.");
|
||||
}
|
||||
|
||||
if (!values.cloud_name) {
|
||||
throw new Error("Cloud URL is required.");
|
||||
}
|
||||
|
||||
if (!values.cloud_url) {
|
||||
throw new Error("Cloud URL is required.");
|
||||
}
|
||||
|
||||
const response = await handleOAuthConfluenceFinalize(
|
||||
values.credential_id,
|
||||
values.cloud_id,
|
||||
values.cloud_name,
|
||||
values.cloud_url
|
||||
);
|
||||
formikHelpers.setSubmitting(false);
|
||||
|
||||
if (response) {
|
||||
setRedirectUrl(response.redirect_url);
|
||||
setStatusMessage("Confluence authorization finalized.");
|
||||
}
|
||||
|
||||
setIsSubmitted(true); // Mark as submitted
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
setStatusMessage("Error during submission.");
|
||||
setStatusDetails(
|
||||
"An error occurred during the submission process. Please try again."
|
||||
);
|
||||
setIsError(true);
|
||||
formikHelpers.setSubmitting(false);
|
||||
}
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting, isValid, setFieldValue }) => (
|
||||
<Form>
|
||||
{/* Debug info
|
||||
<div className="mb-4 p-2 bg-gray-100 rounded text-xs">
|
||||
<pre>
|
||||
isValid: {String(isValid)}
|
||||
errors: {JSON.stringify(errors, null, 2)}
|
||||
values: {JSON.stringify(values, null, 2)}
|
||||
</pre>
|
||||
</div> */}
|
||||
|
||||
{/* Our helper component that reacts to changes in cloud_id */}
|
||||
<UpdateCloudURLOnCloudIdChange
|
||||
accessibleResources={accessibleResources}
|
||||
/>
|
||||
|
||||
<Field type="hidden" name="cloud_name" />
|
||||
<ErrorMessage
|
||||
name="cloud_name"
|
||||
component="div"
|
||||
className="error"
|
||||
/>
|
||||
|
||||
<Field type="hidden" name="cloud_url" />
|
||||
<ErrorMessage
|
||||
name="cloud_url"
|
||||
component="div"
|
||||
className="error"
|
||||
/>
|
||||
|
||||
{!redirectUrl && accessibleResources.length > 0 && (
|
||||
<SelectorFormField
|
||||
name="cloud_id"
|
||||
options={accessibleResources.map((resource) => ({
|
||||
name: `${resource.name} - ${resource.url}`,
|
||||
value: resource.id,
|
||||
}))}
|
||||
onSelect={(selectedValue) => {
|
||||
const selectedResource = accessibleResources.find(
|
||||
(resource) => resource.id === selectedValue
|
||||
);
|
||||
if (selectedResource) {
|
||||
setFieldValue("cloud_id", selectedResource.id);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
<br />
|
||||
{!redirectUrl && (
|
||||
<Button
|
||||
type="submit"
|
||||
size="sm"
|
||||
variant="submit"
|
||||
disabled={!isValid || isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Submitting..." : "Submit"}
|
||||
</Button>
|
||||
)}
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
|
||||
{redirectUrl && !isError && (
|
||||
<div className="mt-4">
|
||||
<p className="text-sm">
|
||||
Authorization finalized. Click{" "}
|
||||
<a href={redirectUrl} className="text-blue-500 underline">
|
||||
here
|
||||
</a>{" "}
|
||||
to continue.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</CardSection>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
@ -1,4 +1,10 @@
|
||||
import React, { Dispatch, FC, SetStateAction, useState } from "react";
|
||||
import React, {
|
||||
Dispatch,
|
||||
FC,
|
||||
SetStateAction,
|
||||
useEffect,
|
||||
useState,
|
||||
} from "react";
|
||||
import CredentialSubText from "@/components/credentials/CredentialFields";
|
||||
import { ConnectionConfiguration } from "@/lib/connectors/connectors";
|
||||
import { TextFormField } from "@/components/admin/connectors/Field";
|
||||
@ -8,6 +14,7 @@ import { AccessTypeGroupSelector } from "@/components/admin/connectors/AccessTyp
|
||||
import { ConfigurableSources } from "@/lib/types";
|
||||
import { Credential } from "@/lib/connectors/credentials";
|
||||
import { RenderField } from "./FieldRendering";
|
||||
import { useFormikContext } from "formik";
|
||||
|
||||
export interface DynamicConnectionFormProps {
|
||||
config: ConnectionConfiguration;
|
||||
@ -22,7 +29,25 @@ const DynamicConnectionForm: FC<DynamicConnectionFormProps> = ({
|
||||
connector,
|
||||
currentCredential,
|
||||
}) => {
|
||||
const { setFieldValue } = useFormikContext<any>(); // Get Formik's context functions
|
||||
|
||||
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
|
||||
const [connectorNameInitialized, setConnectorNameInitialized] =
|
||||
useState(false);
|
||||
|
||||
let initialConnectorName = "";
|
||||
if (config.initialConnectorName) {
|
||||
initialConnectorName =
|
||||
currentCredential?.credential_json?.[config.initialConnectorName] ?? "";
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
const field_value = values["name"];
|
||||
if (initialConnectorName && !connectorNameInitialized && !field_value) {
|
||||
setFieldValue("name", initialConnectorName);
|
||||
setConnectorNameInitialized(true);
|
||||
}
|
||||
}, [initialConnectorName, setFieldValue, values]);
|
||||
|
||||
return (
|
||||
<>
|
||||
|
@ -1,4 +1,4 @@
|
||||
import React, { Dispatch, FC, SetStateAction } from "react";
|
||||
import React, { Dispatch, FC, SetStateAction, useEffect } from "react";
|
||||
import { AdminBooleanFormField } from "@/components/credentials/CredentialFields";
|
||||
import { FileUpload } from "@/components/admin/connectors/FileUpload";
|
||||
import { TabOption } from "@/lib/connectors/connectors";
|
||||
@ -16,6 +16,7 @@ import {
|
||||
TabsList,
|
||||
TabsTrigger,
|
||||
} from "@/components/ui/fully_wrapped_tabs";
|
||||
import { useFormikContext } from "formik";
|
||||
|
||||
interface TabsFieldProps {
|
||||
tabField: TabOption;
|
||||
@ -123,6 +124,8 @@ export const RenderField: FC<RenderFieldProps> = ({
|
||||
connector,
|
||||
currentCredential,
|
||||
}) => {
|
||||
const { setFieldValue } = useFormikContext<any>(); // Get Formik's context functions
|
||||
|
||||
const label =
|
||||
typeof field.label === "function"
|
||||
? field.label(currentCredential)
|
||||
@ -131,6 +134,22 @@ export const RenderField: FC<RenderFieldProps> = ({
|
||||
typeof field.description === "function"
|
||||
? field.description(currentCredential)
|
||||
: field.description;
|
||||
const disabled =
|
||||
typeof field.disabled === "function"
|
||||
? field.disabled(currentCredential)
|
||||
: (field.disabled ?? false);
|
||||
const initialValue =
|
||||
typeof field.initial === "function"
|
||||
? field.initial(currentCredential)
|
||||
: (field.initial ?? "");
|
||||
|
||||
// if initialValue exists, prepopulate the field with it
|
||||
useEffect(() => {
|
||||
const field_value = values[field.name];
|
||||
if (initialValue && field_value === undefined) {
|
||||
setFieldValue(field.name, initialValue);
|
||||
}
|
||||
}, [field.name, initialValue, setFieldValue, values]);
|
||||
|
||||
if (field.type === "tab") {
|
||||
return (
|
||||
@ -176,6 +195,8 @@ export const RenderField: FC<RenderFieldProps> = ({
|
||||
subtext={description}
|
||||
name={field.name}
|
||||
label={label}
|
||||
disabled={disabled}
|
||||
onChange={(e) => setFieldValue(field.name, e.target.value)}
|
||||
/>
|
||||
) : field.type === "text" ? (
|
||||
<TextFormField
|
||||
@ -186,6 +207,8 @@ export const RenderField: FC<RenderFieldProps> = ({
|
||||
name={field.name}
|
||||
isTextArea={field.isTextArea || false}
|
||||
defaultHeight={"h-15"}
|
||||
disabled={disabled}
|
||||
onChange={(e) => setFieldValue(field.name, e.target.value)}
|
||||
/>
|
||||
) : field.type === "string_tab" ? (
|
||||
<div className="text-center">{description}</div>
|
||||
|
@ -132,7 +132,6 @@ export function TextFormField({
|
||||
label,
|
||||
subtext,
|
||||
placeholder,
|
||||
value,
|
||||
type = "text",
|
||||
optional,
|
||||
includeRevert,
|
||||
@ -157,7 +156,6 @@ export function TextFormField({
|
||||
vertical,
|
||||
className,
|
||||
}: {
|
||||
value?: string; // Escape hatch for setting the value of the field - conflicts with Formik
|
||||
name: string;
|
||||
removeLabel?: boolean;
|
||||
label: string;
|
||||
@ -253,7 +251,6 @@ export function TextFormField({
|
||||
min={min}
|
||||
as={isTextArea ? "textarea" : "input"}
|
||||
type={type}
|
||||
defaultValue={value}
|
||||
data-testid={name}
|
||||
name={name}
|
||||
id={name}
|
||||
|
@ -130,6 +130,7 @@ interface BooleanFormFieldProps {
|
||||
small?: boolean;
|
||||
alignTop?: boolean;
|
||||
noLabel?: boolean;
|
||||
disabled?: boolean;
|
||||
onChange?: (e: React.ChangeEvent<HTMLInputElement>) => void;
|
||||
}
|
||||
|
||||
@ -141,6 +142,7 @@ export const AdminBooleanFormField = ({
|
||||
small,
|
||||
checked,
|
||||
alignTop,
|
||||
disabled = false,
|
||||
onChange,
|
||||
}: BooleanFormFieldProps) => {
|
||||
const [field, meta, helpers] = useField(name);
|
||||
@ -152,6 +154,7 @@ export const AdminBooleanFormField = ({
|
||||
type="checkbox"
|
||||
{...field}
|
||||
checked={Boolean(field.value)}
|
||||
disabled={disabled}
|
||||
onChange={(e) => {
|
||||
helpers.setValue(e.target.checked);
|
||||
}}
|
||||
|
@ -43,6 +43,7 @@ export interface Option {
|
||||
currentCredential: Credential<any> | null
|
||||
) => boolean;
|
||||
wrapInCollapsible?: boolean;
|
||||
disabled?: boolean | ((currentCredential: Credential<any> | null) => boolean);
|
||||
}
|
||||
|
||||
export interface SelectOption extends Option {
|
||||
@ -60,6 +61,7 @@ export interface ListOption extends Option {
|
||||
export interface TextOption extends Option {
|
||||
type: "text";
|
||||
default?: string;
|
||||
initial?: string | ((currentCredential: Credential<any> | null) => string);
|
||||
isTextArea?: boolean;
|
||||
}
|
||||
|
||||
@ -105,6 +107,7 @@ export interface TabOption extends Option {
|
||||
export interface ConnectionConfiguration {
|
||||
description: string;
|
||||
subtext?: string;
|
||||
initialConnectorName?: string; // a key in the credential to prepopulate the connector name field
|
||||
values: (
|
||||
| BooleanOption
|
||||
| ListOption
|
||||
@ -389,6 +392,7 @@ export const connectorConfigs: Record<
|
||||
},
|
||||
confluence: {
|
||||
description: "Configure Confluence connector",
|
||||
initialConnectorName: "cloud_name",
|
||||
values: [
|
||||
{
|
||||
type: "checkbox",
|
||||
@ -399,6 +403,12 @@ export const connectorConfigs: Record<
|
||||
default: true,
|
||||
description:
|
||||
"Check if this is a Confluence Cloud instance, uncheck for Confluence Server/Data Center",
|
||||
disabled: (currentCredential) => {
|
||||
if (currentCredential?.credential_json?.confluence_refresh_token) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
},
|
||||
{
|
||||
type: "text",
|
||||
@ -406,6 +416,15 @@ export const connectorConfigs: Record<
|
||||
label: "Wiki Base URL",
|
||||
name: "wiki_base",
|
||||
optional: false,
|
||||
initial: (currentCredential) => {
|
||||
return currentCredential?.credential_json?.wiki_base ?? "";
|
||||
},
|
||||
disabled: (currentCredential) => {
|
||||
if (currentCredential?.credential_json?.confluence_refresh_token) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
description:
|
||||
"The base URL of your Confluence instance (e.g., https://your-domain.atlassian.net/wiki)",
|
||||
},
|
||||
|
@ -27,6 +27,9 @@ export async function getConnectorOauthRedirectUrl(
|
||||
export function useOAuthDetails(sourceType: ValidSources) {
|
||||
return useSWR<OAuthDetails>(
|
||||
`/api/connector/oauth/details/${sourceType}`,
|
||||
errorHandlingFetcher
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
shouldRetryOnError: false,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
@ -1,5 +1,7 @@
|
||||
import {
|
||||
OAuthGoogleDriveCallbackResponse,
|
||||
OAuthBaseCallbackResponse,
|
||||
OAuthConfluenceFinalizeResponse,
|
||||
OAuthConfluencePrepareFinalizationResponse,
|
||||
OAuthPrepareAuthorizationResponse,
|
||||
OAuthSlackCallbackResponse,
|
||||
} from "./types";
|
||||
@ -53,6 +55,10 @@ export async function handleOAuthAuthorizationResponse(
|
||||
return handleOAuthGoogleDriveAuthorizationResponse(code, state);
|
||||
}
|
||||
|
||||
if (connector === "confluence") {
|
||||
return handleOAuthConfluenceAuthorizationResponse(code, state);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
@ -75,7 +81,7 @@ export async function handleOAuthSlackAuthorizationResponse(
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
let errorDetails = `Failed to handle OAuth authorization response: ${response.status}`;
|
||||
let errorDetails = `Failed to handle OAuth Slack authorization response: ${response.status}`;
|
||||
|
||||
try {
|
||||
const responseBody = await response.text(); // Read the body as text
|
||||
@ -96,12 +102,10 @@ export async function handleOAuthSlackAuthorizationResponse(
|
||||
return data;
|
||||
}
|
||||
|
||||
// server side handler to process the oauth redirect callback
|
||||
// https://api.slack.com/authentication/oauth-v2#exchanging
|
||||
export async function handleOAuthGoogleDriveAuthorizationResponse(
|
||||
code: string,
|
||||
state: string
|
||||
): Promise<OAuthGoogleDriveCallbackResponse> {
|
||||
): Promise<OAuthBaseCallbackResponse> {
|
||||
const url = `/api/oauth/connector/google-drive/callback?code=${encodeURIComponent(
|
||||
code
|
||||
)}&state=${encodeURIComponent(state)}`;
|
||||
@ -115,7 +119,7 @@ export async function handleOAuthGoogleDriveAuthorizationResponse(
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
let errorDetails = `Failed to handle OAuth authorization response: ${response.status}`;
|
||||
let errorDetails = `Failed to handle OAuth Google Drive authorization response: ${response.status}`;
|
||||
|
||||
try {
|
||||
const responseBody = await response.text(); // Read the body as text
|
||||
@ -132,6 +136,137 @@ export async function handleOAuthGoogleDriveAuthorizationResponse(
|
||||
}
|
||||
|
||||
// Parse the JSON response
|
||||
const data = (await response.json()) as OAuthGoogleDriveCallbackResponse;
|
||||
const data = (await response.json()) as OAuthBaseCallbackResponse;
|
||||
return data;
|
||||
}
|
||||
|
||||
// call server side helper
|
||||
// https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps
|
||||
export async function handleOAuthConfluenceAuthorizationResponse(
|
||||
code: string,
|
||||
state: string
|
||||
): Promise<OAuthBaseCallbackResponse> {
|
||||
const url = `/api/oauth/connector/confluence/callback?code=${encodeURIComponent(
|
||||
code
|
||||
)}&state=${encodeURIComponent(state)}`;
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({ code, state }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
let errorDetails = `Failed to handle OAuth Confluence authorization response: ${response.status}`;
|
||||
|
||||
try {
|
||||
const responseBody = await response.text(); // Read the body as text
|
||||
errorDetails += `\nResponse Body: ${responseBody}`;
|
||||
} catch (err) {
|
||||
if (err instanceof Error) {
|
||||
errorDetails += `\nUnable to read response body: ${err.message}`;
|
||||
} else {
|
||||
errorDetails += `\nUnable to read response body: Unknown error type`;
|
||||
}
|
||||
}
|
||||
|
||||
throw new Error(errorDetails);
|
||||
}
|
||||
|
||||
// Parse the JSON response
|
||||
const data = (await response.json()) as OAuthBaseCallbackResponse;
|
||||
return data;
|
||||
}
|
||||
|
||||
export async function handleOAuthPrepareFinalization(
|
||||
connector: string,
|
||||
credential: number
|
||||
) {
|
||||
if (connector === "confluence") {
|
||||
return handleOAuthConfluencePrepareFinalization(credential);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// call server side helper
|
||||
// https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps
|
||||
export async function handleOAuthConfluencePrepareFinalization(
|
||||
credential: number
|
||||
): Promise<OAuthConfluencePrepareFinalizationResponse> {
|
||||
const url = `/api/oauth/connector/confluence/accessible-resources?credential_id=${encodeURIComponent(
|
||||
credential
|
||||
)}`;
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
let errorDetails = `Failed to handle OAuth Confluence prepare finalization response: ${response.status}`;
|
||||
|
||||
try {
|
||||
const responseBody = await response.text(); // Read the body as text
|
||||
errorDetails += `\nResponse Body: ${responseBody}`;
|
||||
} catch (err) {
|
||||
if (err instanceof Error) {
|
||||
errorDetails += `\nUnable to read response body: ${err.message}`;
|
||||
} else {
|
||||
errorDetails += `\nUnable to read response body: Unknown error type`;
|
||||
}
|
||||
}
|
||||
|
||||
throw new Error(errorDetails);
|
||||
}
|
||||
|
||||
// Parse the JSON response
|
||||
const data =
|
||||
(await response.json()) as OAuthConfluencePrepareFinalizationResponse;
|
||||
return data;
|
||||
}
|
||||
|
||||
export async function handleOAuthConfluenceFinalize(
|
||||
credential_id: number,
|
||||
cloud_id: string,
|
||||
cloud_name: string,
|
||||
cloud_url: string
|
||||
): Promise<OAuthConfluenceFinalizeResponse> {
|
||||
const url = `/api/oauth/connector/confluence/finalize?credential_id=${encodeURIComponent(
|
||||
credential_id
|
||||
)}&cloud_id=${encodeURIComponent(cloud_id)}&cloud_name=${encodeURIComponent(
|
||||
cloud_name
|
||||
)}&cloud_url=${encodeURIComponent(cloud_url)}`;
|
||||
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
let errorDetails = `Failed to handle OAuth Confluence finalization response: ${response.status}`;
|
||||
|
||||
try {
|
||||
const responseBody = await response.text(); // Read the body as text
|
||||
errorDetails += `\nResponse Body: ${responseBody}`;
|
||||
} catch (err) {
|
||||
if (err instanceof Error) {
|
||||
errorDetails += `\nUnable to read response body: ${err.message}`;
|
||||
} else {
|
||||
errorDetails += `\nUnable to read response body: Unknown error type`;
|
||||
}
|
||||
}
|
||||
|
||||
throw new Error(errorDetails);
|
||||
}
|
||||
|
||||
// Parse the JSON response
|
||||
const data = (await response.json()) as OAuthConfluenceFinalizeResponse;
|
||||
return data;
|
||||
}
|
||||
|
@ -120,6 +120,7 @@ export const SOURCE_METADATA_MAP: SourceMap = {
|
||||
displayName: "Confluence",
|
||||
category: SourceCategory.Wiki,
|
||||
docs: "https://docs.onyx.app/connectors/confluence",
|
||||
oauthSupported: true,
|
||||
},
|
||||
jira: {
|
||||
icon: JiraIcon,
|
||||
|
@ -167,18 +167,36 @@ export interface OAuthPrepareAuthorizationResponse {
|
||||
url: string;
|
||||
}
|
||||
|
||||
export interface OAuthSlackCallbackResponse {
|
||||
export interface OAuthBaseCallbackResponse {
|
||||
success: boolean;
|
||||
message: string;
|
||||
team_id: string;
|
||||
authed_user_id: string;
|
||||
finalize_url: string | null;
|
||||
redirect_on_success: string;
|
||||
}
|
||||
|
||||
export interface OAuthGoogleDriveCallbackResponse {
|
||||
export interface OAuthSlackCallbackResponse extends OAuthBaseCallbackResponse {
|
||||
team_id: string;
|
||||
authed_user_id: string;
|
||||
}
|
||||
|
||||
export interface ConfluenceAccessibleResource {
|
||||
id: string;
|
||||
name: string;
|
||||
url: string;
|
||||
scopes: string[];
|
||||
avatarUrl: string;
|
||||
}
|
||||
|
||||
export interface OAuthConfluencePrepareFinalizationResponse {
|
||||
success: boolean;
|
||||
message: string;
|
||||
redirect_on_success: string;
|
||||
accessible_resources: ConfluenceAccessibleResource[];
|
||||
}
|
||||
|
||||
export interface OAuthConfluenceFinalizeResponse {
|
||||
success: boolean;
|
||||
message: string;
|
||||
redirect_url: string;
|
||||
}
|
||||
|
||||
export interface CCPairBasicInfo {
|
||||
@ -382,6 +400,7 @@ export const oauthSupportedSources: ConfigurableSources[] = [
|
||||
ValidSources.Slack,
|
||||
// NOTE: temporarily disabled until our GDrive App is approved
|
||||
// ValidSources.GoogleDrive,
|
||||
ValidSources.Confluence,
|
||||
];
|
||||
|
||||
export type OAuthSupportedSource = (typeof oauthSupportedSources)[number];
|
||||
|
Loading…
x
Reference in New Issue
Block a user