mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-02 02:30:47 +02:00
Add support for OAuth connectors that require user input (#3571)
* Add support for OAuth connectors that require user input * Cleanup * Fix linear * Small re-naming * Remove console.log
This commit is contained in:
@ -374,7 +374,6 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
|
||||
CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE")
|
||||
|
||||
# Egnyte specific configs
|
||||
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
|
||||
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
|
||||
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
|
||||
|
||||
|
@ -7,7 +7,8 @@ from typing import Any
|
||||
from typing import IO
|
||||
from urllib.parse import quote
|
||||
|
||||
from onyx.configs.app_configs import EGNYTE_BASE_DOMAIN
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.app_configs import EGNYTE_CLIENT_ID
|
||||
from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
@ -124,6 +125,15 @@ def _process_egnyte_file(
|
||||
|
||||
|
||||
class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
class AdditionalOauthKwargs(OAuthConnector.AdditionalOauthKwargs):
|
||||
egnyte_domain: str = Field(
|
||||
title="Egnyte Domain",
|
||||
description=(
|
||||
"The domain for the Egnyte instance "
|
||||
"(e.g. 'company' for company.egnyte.com)"
|
||||
),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
folder_path: str | None = None,
|
||||
@ -139,15 +149,20 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
return DocumentSource.EGNYTE
|
||||
|
||||
@classmethod
|
||||
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
|
||||
def oauth_authorization_url(
|
||||
cls,
|
||||
base_domain: str,
|
||||
state: str,
|
||||
additional_kwargs: dict[str, str],
|
||||
) -> str:
|
||||
if not EGNYTE_CLIENT_ID:
|
||||
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
|
||||
if not EGNYTE_BASE_DOMAIN:
|
||||
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
|
||||
|
||||
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
|
||||
|
||||
callback_uri = get_oauth_callback_uri(base_domain, "egnyte")
|
||||
return (
|
||||
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
|
||||
f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
|
||||
f"?client_id={EGNYTE_CLIENT_ID}"
|
||||
f"&redirect_uri={callback_uri}"
|
||||
f"&scope=Egnyte.filesystem"
|
||||
@ -156,17 +171,23 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
|
||||
def oauth_code_to_token(
|
||||
cls,
|
||||
base_domain: str,
|
||||
code: str,
|
||||
additional_kwargs: dict[str, str],
|
||||
) -> dict[str, Any]:
|
||||
if not EGNYTE_CLIENT_ID:
|
||||
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
|
||||
if not EGNYTE_CLIENT_SECRET:
|
||||
raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
|
||||
if not EGNYTE_BASE_DOMAIN:
|
||||
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
|
||||
|
||||
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
|
||||
|
||||
# Exchange code for token
|
||||
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
|
||||
url = f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
|
||||
redirect_uri = get_oauth_callback_uri(base_domain, "egnyte")
|
||||
|
||||
data = {
|
||||
"client_id": EGNYTE_CLIENT_ID,
|
||||
"client_secret": EGNYTE_CLIENT_SECRET,
|
||||
@ -191,7 +212,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
|
||||
token_data = response.json()
|
||||
return {
|
||||
"domain": EGNYTE_BASE_DOMAIN,
|
||||
"domain": oauth_kwargs.egnyte_domain,
|
||||
"access_token": token_data["access_token"],
|
||||
}
|
||||
|
||||
@ -215,7 +236,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
"list_content": True,
|
||||
}
|
||||
|
||||
url_encoded_path = quote(path or "", safe="")
|
||||
url_encoded_path = quote(path or "")
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}"
|
||||
response = request_with_retries(
|
||||
method="GET", url=url, headers=headers, params=params
|
||||
@ -271,7 +292,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.access_token}",
|
||||
}
|
||||
url_encoded_path = quote(file["path"], safe="")
|
||||
url_encoded_path = quote(file["path"])
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}"
|
||||
response = request_with_retries(
|
||||
method="GET",
|
||||
|
@ -2,6 +2,8 @@ import abc
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
@ -66,6 +68,10 @@ class SlimConnector(BaseConnector):
|
||||
|
||||
|
||||
class OAuthConnector(BaseConnector):
|
||||
class AdditionalOauthKwargs(BaseModel):
|
||||
# if overridden, all fields should be str type
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def oauth_id(cls) -> DocumentSource:
|
||||
@ -73,12 +79,22 @@ class OAuthConnector(BaseConnector):
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
|
||||
def oauth_authorization_url(
|
||||
cls,
|
||||
base_domain: str,
|
||||
state: str,
|
||||
additional_kwargs: dict[str, str],
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
|
||||
def oauth_code_to_token(
|
||||
cls,
|
||||
base_domain: str,
|
||||
code: str,
|
||||
additional_kwargs: dict[str, str],
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
@ -77,7 +77,9 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
return DocumentSource.LINEAR
|
||||
|
||||
@classmethod
|
||||
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
|
||||
def oauth_authorization_url(
|
||||
cls, base_domain: str, state: str, additional_kwargs: dict[str, str]
|
||||
) -> str:
|
||||
if not LINEAR_CLIENT_ID:
|
||||
raise ValueError("LINEAR_CLIENT_ID environment variable must be set")
|
||||
|
||||
@ -92,7 +94,9 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
|
||||
def oauth_code_to_token(
|
||||
cls, base_domain: str, code: str, additional_kwargs: dict[str, str]
|
||||
) -> dict[str, Any]:
|
||||
data = {
|
||||
"code": code,
|
||||
"redirect_uri": get_oauth_callback_uri(
|
||||
|
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
from typing import cast
|
||||
@ -6,7 +7,9 @@ from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
@ -28,6 +31,8 @@ router = APIRouter(prefix="/connector/oauth")
|
||||
|
||||
_OAUTH_STATE_KEY_FMT = "oauth_state:{state}"
|
||||
_OAUTH_STATE_EXPIRATION_SECONDS = 10 * 60 # 10 minutes
|
||||
_DESIRED_RETURN_URL_KEY = "desired_return_url"
|
||||
_ADDITIONAL_KWARGS_KEY = "additional_kwargs"
|
||||
|
||||
# Cache for OAuth connectors, populated at module load time
|
||||
_OAUTH_CONNECTORS: dict[DocumentSource, type[OAuthConnector]] = {}
|
||||
@ -51,12 +56,36 @@ def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]:
|
||||
_discover_oauth_connectors()
|
||||
|
||||
|
||||
def _get_additional_kwargs(
|
||||
request: Request, connector_cls: type[OAuthConnector], args_to_ignore: list[str]
|
||||
) -> dict[str, str]:
|
||||
# get additional kwargs from request
|
||||
# e.g. anything except for desired_return_url
|
||||
additional_kwargs_dict = {
|
||||
k: v for k, v in request.query_params.items() if k not in args_to_ignore
|
||||
}
|
||||
try:
|
||||
# validate
|
||||
connector_cls.AdditionalOauthKwargs(**additional_kwargs_dict)
|
||||
except ValidationError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Invalid additional kwargs. Got {additional_kwargs_dict}, expected "
|
||||
f"{connector_cls.AdditionalOauthKwargs.model_json_schema()}"
|
||||
),
|
||||
)
|
||||
|
||||
return additional_kwargs_dict
|
||||
|
||||
|
||||
class AuthorizeResponse(BaseModel):
|
||||
redirect_url: str
|
||||
|
||||
|
||||
@router.get("/authorize/{source}")
|
||||
def oauth_authorize(
|
||||
request: Request,
|
||||
source: DocumentSource,
|
||||
desired_return_url: Annotated[str | None, Query()] = None,
|
||||
_: User = Depends(current_user),
|
||||
@ -71,6 +100,12 @@ def oauth_authorize(
|
||||
connector_cls = oauth_connectors[source]
|
||||
base_url = WEB_DOMAIN
|
||||
|
||||
# get additional kwargs from request
|
||||
# e.g. anything except for desired_return_url
|
||||
additional_kwargs = _get_additional_kwargs(
|
||||
request, connector_cls, ["desired_return_url"]
|
||||
)
|
||||
|
||||
# store state in redis
|
||||
if not desired_return_url:
|
||||
desired_return_url = f"{base_url}/admin/connectors/{source}?step=0"
|
||||
@ -78,12 +113,19 @@ def oauth_authorize(
|
||||
state = str(uuid.uuid4())
|
||||
redis_client.set(
|
||||
_OAUTH_STATE_KEY_FMT.format(state=state),
|
||||
desired_return_url,
|
||||
json.dumps(
|
||||
{
|
||||
_DESIRED_RETURN_URL_KEY: desired_return_url,
|
||||
_ADDITIONAL_KWARGS_KEY: additional_kwargs,
|
||||
}
|
||||
),
|
||||
ex=_OAUTH_STATE_EXPIRATION_SECONDS,
|
||||
)
|
||||
|
||||
return AuthorizeResponse(
|
||||
redirect_url=connector_cls.oauth_authorization_url(base_url, state)
|
||||
redirect_url=connector_cls.oauth_authorization_url(
|
||||
base_url, state, additional_kwargs
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -110,15 +152,18 @@ def oauth_callback(
|
||||
|
||||
# get state from redis
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
original_url_bytes = cast(
|
||||
oauth_state_bytes = cast(
|
||||
bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state))
|
||||
)
|
||||
if not original_url_bytes:
|
||||
if not oauth_state_bytes:
|
||||
raise HTTPException(status_code=400, detail="Invalid OAuth state")
|
||||
original_url = original_url_bytes.decode("utf-8")
|
||||
oauth_state = json.loads(oauth_state_bytes.decode("utf-8"))
|
||||
|
||||
desired_return_url = cast(str, oauth_state[_DESIRED_RETURN_URL_KEY])
|
||||
additional_kwargs = cast(dict[str, str], oauth_state[_ADDITIONAL_KWARGS_KEY])
|
||||
|
||||
base_url = WEB_DOMAIN
|
||||
token_info = connector_cls.oauth_code_to_token(base_url, code)
|
||||
token_info = connector_cls.oauth_code_to_token(base_url, code, additional_kwargs)
|
||||
|
||||
# Create a new credential with the token info
|
||||
credential_data = CredentialBase(
|
||||
@ -136,8 +181,52 @@ def oauth_callback(
|
||||
|
||||
return CallbackResponse(
|
||||
redirect_url=(
|
||||
f"{original_url}?credentialId={credential.id}"
|
||||
if "?" not in original_url
|
||||
else f"{original_url}&credentialId={credential.id}"
|
||||
f"{desired_return_url}?credentialId={credential.id}"
|
||||
if "?" not in desired_return_url
|
||||
else f"{desired_return_url}&credentialId={credential.id}"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class OAuthAdditionalKwargDescription(BaseModel):
|
||||
name: str
|
||||
display_name: str
|
||||
description: str
|
||||
|
||||
|
||||
class OAuthDetails(BaseModel):
|
||||
oauth_enabled: bool
|
||||
additional_kwargs: list[OAuthAdditionalKwargDescription]
|
||||
|
||||
|
||||
@router.get("/details/{source}")
|
||||
def oauth_details(
|
||||
source: DocumentSource,
|
||||
_: User = Depends(current_user),
|
||||
) -> OAuthDetails:
|
||||
oauth_connectors = _discover_oauth_connectors()
|
||||
|
||||
if source not in oauth_connectors:
|
||||
return OAuthDetails(
|
||||
oauth_enabled=False,
|
||||
additional_kwargs=[],
|
||||
)
|
||||
|
||||
connector_cls = oauth_connectors[source]
|
||||
|
||||
additional_kwarg_descriptions = []
|
||||
for key, value in connector_cls.AdditionalOauthKwargs.model_json_schema()[
|
||||
"properties"
|
||||
].items():
|
||||
additional_kwarg_descriptions.append(
|
||||
OAuthAdditionalKwargDescription(
|
||||
name=key,
|
||||
display_name=value.get("title", key),
|
||||
description=value.get("description", ""),
|
||||
)
|
||||
)
|
||||
|
||||
return OAuthDetails(
|
||||
oauth_enabled=True,
|
||||
additional_kwargs=additional_kwarg_descriptions,
|
||||
)
|
||||
|
@ -196,7 +196,6 @@ services:
|
||||
# Egnyte OAuth Configs
|
||||
- EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-}
|
||||
- EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-}
|
||||
- EGNYTE_BASE_DOMAIN=${EGNYTE_BASE_DOMAIN:-}
|
||||
- EGNYTE_LOCALHOST_OVERRIDE=${EGNYTE_LOCALHOST_OVERRIDE:-}
|
||||
# Celery Configs (defaults are set in the supervisord.conf file.
|
||||
# prefer doing that to have one source of defaults)
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import useSWR, { mutate } from "swr";
|
||||
import { HealthCheckBanner } from "@/components/health/healthcheck";
|
||||
|
||||
import Title from "@/components/ui/title";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
@ -19,12 +18,12 @@ import AdvancedFormPage from "./pages/Advanced";
|
||||
import DynamicConnectionForm from "./pages/DynamicConnectorCreationForm";
|
||||
import CreateCredential from "@/components/credentials/actions/CreateCredential";
|
||||
import ModifyCredential from "@/components/credentials/actions/ModifyCredential";
|
||||
import { ConfigurableSources, oauthSupportedSources } from "@/lib/types";
|
||||
import {
|
||||
ConfigurableSources,
|
||||
oauthSupportedSources,
|
||||
ValidSources,
|
||||
} from "@/lib/types";
|
||||
import { Credential, credentialTemplates } from "@/lib/connectors/credentials";
|
||||
Credential,
|
||||
credentialTemplates,
|
||||
OAuthDetails,
|
||||
} from "@/lib/connectors/credentials";
|
||||
import {
|
||||
ConnectionConfiguration,
|
||||
connectorConfigs,
|
||||
@ -37,7 +36,6 @@ import {
|
||||
ConnectorBase,
|
||||
} from "@/lib/connectors/connectors";
|
||||
import { Modal } from "@/components/Modal";
|
||||
import GDriveMain from "./pages/gdrive/GoogleDrivePage";
|
||||
import { GmailMain } from "./pages/gmail/GmailPage";
|
||||
import {
|
||||
useGmailCredentials,
|
||||
@ -54,7 +52,12 @@ import {
|
||||
NEXT_PUBLIC_TEST_ENV,
|
||||
} from "@/lib/constants";
|
||||
import TemporaryLoadingModal from "@/components/TemporaryLoadingModal";
|
||||
import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth";
|
||||
import {
|
||||
getConnectorOauthRedirectUrl,
|
||||
useOAuthDetails,
|
||||
} from "@/lib/connectors/oauth";
|
||||
import { CreateStdOAuthCredential } from "@/components/credentials/actions/CreateStdOAuthCredential";
|
||||
import { Spinner } from "@/components/Spinner";
|
||||
export interface AdvancedConfig {
|
||||
refreshFreq: number;
|
||||
pruneFreq: number;
|
||||
@ -144,7 +147,8 @@ export default function AddConnector({
|
||||
// State for managing credentials and files
|
||||
const [currentCredential, setCurrentCredential] =
|
||||
useState<Credential<any> | null>(null);
|
||||
const [createConnectorToggle, setCreateConnectorToggle] = useState(false);
|
||||
const [createCredentialFormToggle, setCreateCredentialFormToggle] =
|
||||
useState(false);
|
||||
|
||||
// Fetch credentials data
|
||||
const { data: credentials } = useSWR<Credential<any>[]>(
|
||||
@ -159,6 +163,9 @@ export default function AddConnector({
|
||||
{ refreshInterval: 5000 }
|
||||
);
|
||||
|
||||
const { data: oauthDetails, isLoading: oauthDetailsLoading } =
|
||||
useOAuthDetails(connector);
|
||||
|
||||
// Get credential template and configuration
|
||||
const credentialTemplate = credentialTemplates[connector];
|
||||
const configuration: ConnectionConfiguration = connectorConfigs[connector];
|
||||
@ -450,19 +457,33 @@ export default function AddConnector({
|
||||
onDeleteCredential={onDeleteCredential}
|
||||
onSwitch={onSwap}
|
||||
/>
|
||||
{!createConnectorToggle && (
|
||||
{!createCredentialFormToggle && (
|
||||
<div className="mt-6 flex space-x-4">
|
||||
{/* Button to pop up a form to manually enter credentials */}
|
||||
<button
|
||||
className="mt-6 text-sm bg-background-900 px-2 py-1.5 flex text-text-200 flex-none rounded mr-4"
|
||||
onClick={async () => {
|
||||
if (oauthDetails && oauthDetails.oauth_enabled) {
|
||||
if (oauthDetails.additional_kwargs.length > 0) {
|
||||
setCreateCredentialFormToggle(true);
|
||||
} else {
|
||||
const redirectUrl =
|
||||
await getConnectorOauthRedirectUrl(connector);
|
||||
await getConnectorOauthRedirectUrl(
|
||||
connector,
|
||||
{}
|
||||
);
|
||||
// if redirect is supported, just use it
|
||||
if (redirectUrl) {
|
||||
window.location.href = redirectUrl;
|
||||
} else {
|
||||
setCreateConnectorToggle(
|
||||
setCreateCredentialFormToggle(
|
||||
(createConnectorToggle) =>
|
||||
!createConnectorToggle
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
setCreateCredentialFormToggle(
|
||||
(createConnectorToggle) =>
|
||||
!createConnectorToggle
|
||||
);
|
||||
@ -491,25 +512,42 @@ export default function AddConnector({
|
||||
</div>
|
||||
)}
|
||||
|
||||
{createConnectorToggle && (
|
||||
{createCredentialFormToggle && (
|
||||
<Modal
|
||||
className="max-w-3xl rounded-lg"
|
||||
onOutsideClick={() => setCreateConnectorToggle(false)}
|
||||
onOutsideClick={() =>
|
||||
setCreateCredentialFormToggle(false)
|
||||
}
|
||||
>
|
||||
{oauthDetailsLoading ? (
|
||||
<Spinner />
|
||||
) : (
|
||||
<>
|
||||
<Title className="mb-2 text-lg">
|
||||
Create a {getSourceDisplayName(connector)}{" "}
|
||||
credential
|
||||
</Title>
|
||||
{oauthDetails && oauthDetails.oauth_enabled ? (
|
||||
<CreateStdOAuthCredential
|
||||
sourceType={connector}
|
||||
additionalFields={
|
||||
oauthDetails.additional_kwargs
|
||||
}
|
||||
/>
|
||||
) : (
|
||||
<CreateCredential
|
||||
close
|
||||
refresh={refresh}
|
||||
sourceType={connector}
|
||||
setPopup={setPopup}
|
||||
onSwitch={onSwap}
|
||||
onClose={() => setCreateConnectorToggle(false)}
|
||||
onClose={() =>
|
||||
setCreateCredentialFormToggle(false)
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</Modal>
|
||||
)}
|
||||
</>
|
||||
|
@ -28,7 +28,12 @@ import {
|
||||
ConfluenceCredentialJson,
|
||||
Credential,
|
||||
} from "@/lib/connectors/credentials";
|
||||
import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth";
|
||||
import {
|
||||
getConnectorOauthRedirectUrl,
|
||||
useOAuthDetails,
|
||||
} from "@/lib/connectors/oauth";
|
||||
import { Spinner } from "@/components/Spinner";
|
||||
import { CreateStdOAuthCredential } from "@/components/credentials/actions/CreateStdOAuthCredential";
|
||||
|
||||
export default function CredentialSection({
|
||||
ccPair,
|
||||
@ -39,16 +44,6 @@ export default function CredentialSection({
|
||||
sourceType: ValidSources;
|
||||
refresh: () => void;
|
||||
}) {
|
||||
const makeShowCreateCredential = async () => {
|
||||
const redirectUrl = await getConnectorOauthRedirectUrl(sourceType);
|
||||
if (redirectUrl) {
|
||||
window.location.href = redirectUrl;
|
||||
} else {
|
||||
setShowModifyCredential(false);
|
||||
setShowCreateCredential(true);
|
||||
}
|
||||
};
|
||||
|
||||
const { data: credentials } = useSWR<Credential<ConfluenceCredentialJson>[]>(
|
||||
buildSimilarCredentialInfoURL(sourceType),
|
||||
errorHandlingFetcher,
|
||||
@ -59,6 +54,28 @@ export default function CredentialSection({
|
||||
errorHandlingFetcher,
|
||||
{ refreshInterval: 5000 }
|
||||
);
|
||||
const { data: oauthDetails, isLoading: oauthDetailsLoading } =
|
||||
useOAuthDetails(sourceType);
|
||||
|
||||
const makeShowCreateCredential = async () => {
|
||||
if (oauthDetailsLoading || !oauthDetails) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (oauthDetails.oauth_enabled) {
|
||||
if (oauthDetails.additional_kwargs.length > 0) {
|
||||
setShowCreateCredential(true);
|
||||
} else {
|
||||
const redirectUrl = await getConnectorOauthRedirectUrl(sourceType, {});
|
||||
if (redirectUrl) {
|
||||
window.location.href = redirectUrl;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
setShowModifyCredential(false);
|
||||
setShowCreateCredential(true);
|
||||
}
|
||||
};
|
||||
|
||||
const onSwap = async (
|
||||
selectedCredential: Credential<any>,
|
||||
@ -193,6 +210,16 @@ export default function CredentialSection({
|
||||
className="max-w-3xl rounded-lg"
|
||||
title={`Create ${getSourceDisplayName(sourceType)} Credential`}
|
||||
>
|
||||
{oauthDetailsLoading ? (
|
||||
<Spinner />
|
||||
) : (
|
||||
<>
|
||||
{oauthDetails && oauthDetails.oauth_enabled ? (
|
||||
<CreateStdOAuthCredential
|
||||
sourceType={sourceType}
|
||||
additionalFields={oauthDetails.additional_kwargs}
|
||||
/>
|
||||
) : (
|
||||
<CreateCredential
|
||||
sourceType={sourceType}
|
||||
swapConnector={ccPair.connector}
|
||||
@ -200,6 +227,9 @@ export default function CredentialSection({
|
||||
onSwap={onSwap}
|
||||
onClose={closeCreateCredential}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</Modal>
|
||||
)}
|
||||
</div>
|
||||
|
@ -0,0 +1,88 @@
|
||||
import * as Yup from "yup";
|
||||
|
||||
import React from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { TextFormField } from "@/components/admin/connectors/Field";
|
||||
import { Form, Formik, FormikHelpers } from "formik";
|
||||
import CardSection from "@/components/admin/CardSection";
|
||||
import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth";
|
||||
import { OAuthAdditionalKwargDescription } from "@/lib/connectors/credentials";
|
||||
|
||||
type formType = {
|
||||
[key: string]: any; // For additional credential fields
|
||||
};
|
||||
|
||||
export function CreateStdOAuthCredential({
|
||||
sourceType,
|
||||
additionalFields,
|
||||
}: {
|
||||
// Source information
|
||||
sourceType: ValidSources;
|
||||
|
||||
additionalFields: OAuthAdditionalKwargDescription[];
|
||||
}) {
|
||||
const handleSubmit = async (
|
||||
values: formType,
|
||||
formikHelpers: FormikHelpers<formType>
|
||||
) => {
|
||||
const { setSubmitting, validateForm } = formikHelpers;
|
||||
|
||||
const errors = await validateForm(values);
|
||||
if (Object.keys(errors).length > 0) {
|
||||
formikHelpers.setErrors(errors);
|
||||
return;
|
||||
}
|
||||
|
||||
setSubmitting(true);
|
||||
formikHelpers.setSubmitting(true);
|
||||
|
||||
const redirectUrl = await getConnectorOauthRedirectUrl(sourceType, values);
|
||||
|
||||
if (!redirectUrl) {
|
||||
throw new Error("No redirect URL found for OAuth connector");
|
||||
}
|
||||
|
||||
window.location.href = redirectUrl;
|
||||
};
|
||||
|
||||
return (
|
||||
<Formik
|
||||
initialValues={
|
||||
{
|
||||
...Object.fromEntries(additionalFields.map((field) => [field, ""])),
|
||||
} as formType
|
||||
}
|
||||
validationSchema={Yup.object().shape({
|
||||
...Object.fromEntries(
|
||||
additionalFields.map((field) => [field.name, Yup.string().required()])
|
||||
),
|
||||
})}
|
||||
onSubmit={(values, formikHelpers) => {
|
||||
handleSubmit(values, formikHelpers);
|
||||
}}
|
||||
>
|
||||
{() => (
|
||||
<Form className="w-full flex items-stretch">
|
||||
<CardSection className="w-full !border-0 mt-4 flex flex-col gap-y-6">
|
||||
{additionalFields.map((field) => (
|
||||
<TextFormField
|
||||
key={field.name}
|
||||
name={field.name}
|
||||
label={field.display_name}
|
||||
subtext={field.description}
|
||||
type="text"
|
||||
/>
|
||||
))}
|
||||
|
||||
<div className="flex w-full">
|
||||
<Button type="submit" className="flex text-sm">
|
||||
Create
|
||||
</Button>
|
||||
</div>
|
||||
</CardSection>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
);
|
||||
}
|
@ -1,5 +1,16 @@
|
||||
import { ValidSources } from "../types";
|
||||
|
||||
export interface OAuthAdditionalKwargDescription {
|
||||
name: string;
|
||||
display_name: string;
|
||||
description: string;
|
||||
}
|
||||
|
||||
export interface OAuthDetails {
|
||||
oauth_enabled: boolean;
|
||||
additional_kwargs: OAuthAdditionalKwargDescription[];
|
||||
}
|
||||
|
||||
export interface CredentialBase<T> {
|
||||
credential_json: T;
|
||||
admin_public: boolean;
|
||||
|
@ -1,12 +1,18 @@
|
||||
import useSWR from "swr";
|
||||
import { ValidSources } from "../types";
|
||||
import { OAuthDetails } from "./credentials";
|
||||
import { errorHandlingFetcher } from "../fetcher";
|
||||
|
||||
export async function getConnectorOauthRedirectUrl(
|
||||
connector: ValidSources
|
||||
connector: ValidSources,
|
||||
additional_kwargs: Record<string, string>
|
||||
): Promise<string | null> {
|
||||
const queryParams = new URLSearchParams({
|
||||
desired_return_url: window.location.href,
|
||||
...additional_kwargs,
|
||||
});
|
||||
const response = await fetch(
|
||||
`/api/connector/oauth/authorize/${connector}?desired_return_url=${encodeURIComponent(
|
||||
window.location.href
|
||||
)}`
|
||||
`/api/connector/oauth/authorize/${connector}?${queryParams.toString()}`
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
@ -17,3 +23,10 @@ export async function getConnectorOauthRedirectUrl(
|
||||
const data = await response.json();
|
||||
return data.redirect_url as string;
|
||||
}
|
||||
|
||||
export function useOAuthDetails(sourceType: ValidSources) {
|
||||
return useSWR<OAuthDetails>(
|
||||
`/api/connector/oauth/details/${sourceType}`,
|
||||
errorHandlingFetcher
|
||||
);
|
||||
}
|
||||
|
Reference in New Issue
Block a user