mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-04 11:41:04 +02:00
Add support for OAuth connectors that require user input (#3571)
* Add support for OAuth connectors that require user input * Cleanup * Fix linear * Small re-naming * Remove console.log
This commit is contained in:
@ -374,7 +374,6 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
|
|||||||
CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE")
|
CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE")
|
||||||
|
|
||||||
# Egnyte specific configs
|
# Egnyte specific configs
|
||||||
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
|
|
||||||
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
|
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
|
||||||
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
|
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
|
||||||
|
|
||||||
|
@ -7,7 +7,8 @@ from typing import Any
|
|||||||
from typing import IO
|
from typing import IO
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
from onyx.configs.app_configs import EGNYTE_BASE_DOMAIN
|
from pydantic import Field
|
||||||
|
|
||||||
from onyx.configs.app_configs import EGNYTE_CLIENT_ID
|
from onyx.configs.app_configs import EGNYTE_CLIENT_ID
|
||||||
from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET
|
from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET
|
||||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||||
@ -124,6 +125,15 @@ def _process_egnyte_file(
|
|||||||
|
|
||||||
|
|
||||||
class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||||
|
class AdditionalOauthKwargs(OAuthConnector.AdditionalOauthKwargs):
|
||||||
|
egnyte_domain: str = Field(
|
||||||
|
title="Egnyte Domain",
|
||||||
|
description=(
|
||||||
|
"The domain for the Egnyte instance "
|
||||||
|
"(e.g. 'company' for company.egnyte.com)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
folder_path: str | None = None,
|
folder_path: str | None = None,
|
||||||
@ -139,15 +149,20 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
|||||||
return DocumentSource.EGNYTE
|
return DocumentSource.EGNYTE
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
|
def oauth_authorization_url(
|
||||||
|
cls,
|
||||||
|
base_domain: str,
|
||||||
|
state: str,
|
||||||
|
additional_kwargs: dict[str, str],
|
||||||
|
) -> str:
|
||||||
if not EGNYTE_CLIENT_ID:
|
if not EGNYTE_CLIENT_ID:
|
||||||
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
|
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
|
||||||
if not EGNYTE_BASE_DOMAIN:
|
|
||||||
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
|
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
|
||||||
|
|
||||||
callback_uri = get_oauth_callback_uri(base_domain, "egnyte")
|
callback_uri = get_oauth_callback_uri(base_domain, "egnyte")
|
||||||
return (
|
return (
|
||||||
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
|
f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
|
||||||
f"?client_id={EGNYTE_CLIENT_ID}"
|
f"?client_id={EGNYTE_CLIENT_ID}"
|
||||||
f"&redirect_uri={callback_uri}"
|
f"&redirect_uri={callback_uri}"
|
||||||
f"&scope=Egnyte.filesystem"
|
f"&scope=Egnyte.filesystem"
|
||||||
@ -156,17 +171,23 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
|
def oauth_code_to_token(
|
||||||
|
cls,
|
||||||
|
base_domain: str,
|
||||||
|
code: str,
|
||||||
|
additional_kwargs: dict[str, str],
|
||||||
|
) -> dict[str, Any]:
|
||||||
if not EGNYTE_CLIENT_ID:
|
if not EGNYTE_CLIENT_ID:
|
||||||
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
|
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
|
||||||
if not EGNYTE_CLIENT_SECRET:
|
if not EGNYTE_CLIENT_SECRET:
|
||||||
raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
|
raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
|
||||||
if not EGNYTE_BASE_DOMAIN:
|
|
||||||
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
|
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
|
||||||
|
|
||||||
# Exchange code for token
|
# Exchange code for token
|
||||||
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
|
url = f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
|
||||||
redirect_uri = get_oauth_callback_uri(base_domain, "egnyte")
|
redirect_uri = get_oauth_callback_uri(base_domain, "egnyte")
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"client_id": EGNYTE_CLIENT_ID,
|
"client_id": EGNYTE_CLIENT_ID,
|
||||||
"client_secret": EGNYTE_CLIENT_SECRET,
|
"client_secret": EGNYTE_CLIENT_SECRET,
|
||||||
@ -191,7 +212,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
|||||||
|
|
||||||
token_data = response.json()
|
token_data = response.json()
|
||||||
return {
|
return {
|
||||||
"domain": EGNYTE_BASE_DOMAIN,
|
"domain": oauth_kwargs.egnyte_domain,
|
||||||
"access_token": token_data["access_token"],
|
"access_token": token_data["access_token"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -215,7 +236,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
|||||||
"list_content": True,
|
"list_content": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
url_encoded_path = quote(path or "", safe="")
|
url_encoded_path = quote(path or "")
|
||||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}"
|
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}"
|
||||||
response = request_with_retries(
|
response = request_with_retries(
|
||||||
method="GET", url=url, headers=headers, params=params
|
method="GET", url=url, headers=headers, params=params
|
||||||
@ -271,7 +292,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
|||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {self.access_token}",
|
"Authorization": f"Bearer {self.access_token}",
|
||||||
}
|
}
|
||||||
url_encoded_path = quote(file["path"], safe="")
|
url_encoded_path = quote(file["path"])
|
||||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}"
|
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}"
|
||||||
response = request_with_retries(
|
response = request_with_retries(
|
||||||
method="GET",
|
method="GET",
|
||||||
|
@ -2,6 +2,8 @@ import abc
|
|||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from onyx.configs.constants import DocumentSource
|
from onyx.configs.constants import DocumentSource
|
||||||
from onyx.connectors.models import Document
|
from onyx.connectors.models import Document
|
||||||
from onyx.connectors.models import SlimDocument
|
from onyx.connectors.models import SlimDocument
|
||||||
@ -66,6 +68,10 @@ class SlimConnector(BaseConnector):
|
|||||||
|
|
||||||
|
|
||||||
class OAuthConnector(BaseConnector):
|
class OAuthConnector(BaseConnector):
|
||||||
|
class AdditionalOauthKwargs(BaseModel):
|
||||||
|
# if overridden, all fields should be str type
|
||||||
|
pass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def oauth_id(cls) -> DocumentSource:
|
def oauth_id(cls) -> DocumentSource:
|
||||||
@ -73,12 +79,22 @@ class OAuthConnector(BaseConnector):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
|
def oauth_authorization_url(
|
||||||
|
cls,
|
||||||
|
base_domain: str,
|
||||||
|
state: str,
|
||||||
|
additional_kwargs: dict[str, str],
|
||||||
|
) -> str:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
|
def oauth_code_to_token(
|
||||||
|
cls,
|
||||||
|
base_domain: str,
|
||||||
|
code: str,
|
||||||
|
additional_kwargs: dict[str, str],
|
||||||
|
) -> dict[str, Any]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@ -77,7 +77,9 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
|||||||
return DocumentSource.LINEAR
|
return DocumentSource.LINEAR
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
|
def oauth_authorization_url(
|
||||||
|
cls, base_domain: str, state: str, additional_kwargs: dict[str, str]
|
||||||
|
) -> str:
|
||||||
if not LINEAR_CLIENT_ID:
|
if not LINEAR_CLIENT_ID:
|
||||||
raise ValueError("LINEAR_CLIENT_ID environment variable must be set")
|
raise ValueError("LINEAR_CLIENT_ID environment variable must be set")
|
||||||
|
|
||||||
@ -92,7 +94,9 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
|
def oauth_code_to_token(
|
||||||
|
cls, base_domain: str, code: str, additional_kwargs: dict[str, str]
|
||||||
|
) -> dict[str, Any]:
|
||||||
data = {
|
data = {
|
||||||
"code": code,
|
"code": code,
|
||||||
"redirect_uri": get_oauth_callback_uri(
|
"redirect_uri": get_oauth_callback_uri(
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from typing import cast
|
from typing import cast
|
||||||
@ -6,7 +7,9 @@ from fastapi import APIRouter
|
|||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from fastapi import Query
|
from fastapi import Query
|
||||||
|
from fastapi import Request
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from pydantic import ValidationError
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from onyx.auth.users import current_user
|
from onyx.auth.users import current_user
|
||||||
@ -28,6 +31,8 @@ router = APIRouter(prefix="/connector/oauth")
|
|||||||
|
|
||||||
_OAUTH_STATE_KEY_FMT = "oauth_state:{state}"
|
_OAUTH_STATE_KEY_FMT = "oauth_state:{state}"
|
||||||
_OAUTH_STATE_EXPIRATION_SECONDS = 10 * 60 # 10 minutes
|
_OAUTH_STATE_EXPIRATION_SECONDS = 10 * 60 # 10 minutes
|
||||||
|
_DESIRED_RETURN_URL_KEY = "desired_return_url"
|
||||||
|
_ADDITIONAL_KWARGS_KEY = "additional_kwargs"
|
||||||
|
|
||||||
# Cache for OAuth connectors, populated at module load time
|
# Cache for OAuth connectors, populated at module load time
|
||||||
_OAUTH_CONNECTORS: dict[DocumentSource, type[OAuthConnector]] = {}
|
_OAUTH_CONNECTORS: dict[DocumentSource, type[OAuthConnector]] = {}
|
||||||
@ -51,12 +56,36 @@ def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]:
|
|||||||
_discover_oauth_connectors()
|
_discover_oauth_connectors()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_additional_kwargs(
|
||||||
|
request: Request, connector_cls: type[OAuthConnector], args_to_ignore: list[str]
|
||||||
|
) -> dict[str, str]:
|
||||||
|
# get additional kwargs from request
|
||||||
|
# e.g. anything except for desired_return_url
|
||||||
|
additional_kwargs_dict = {
|
||||||
|
k: v for k, v in request.query_params.items() if k not in args_to_ignore
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
# validate
|
||||||
|
connector_cls.AdditionalOauthKwargs(**additional_kwargs_dict)
|
||||||
|
except ValidationError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=(
|
||||||
|
f"Invalid additional kwargs. Got {additional_kwargs_dict}, expected "
|
||||||
|
f"{connector_cls.AdditionalOauthKwargs.model_json_schema()}"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return additional_kwargs_dict
|
||||||
|
|
||||||
|
|
||||||
class AuthorizeResponse(BaseModel):
|
class AuthorizeResponse(BaseModel):
|
||||||
redirect_url: str
|
redirect_url: str
|
||||||
|
|
||||||
|
|
||||||
@router.get("/authorize/{source}")
|
@router.get("/authorize/{source}")
|
||||||
def oauth_authorize(
|
def oauth_authorize(
|
||||||
|
request: Request,
|
||||||
source: DocumentSource,
|
source: DocumentSource,
|
||||||
desired_return_url: Annotated[str | None, Query()] = None,
|
desired_return_url: Annotated[str | None, Query()] = None,
|
||||||
_: User = Depends(current_user),
|
_: User = Depends(current_user),
|
||||||
@ -71,6 +100,12 @@ def oauth_authorize(
|
|||||||
connector_cls = oauth_connectors[source]
|
connector_cls = oauth_connectors[source]
|
||||||
base_url = WEB_DOMAIN
|
base_url = WEB_DOMAIN
|
||||||
|
|
||||||
|
# get additional kwargs from request
|
||||||
|
# e.g. anything except for desired_return_url
|
||||||
|
additional_kwargs = _get_additional_kwargs(
|
||||||
|
request, connector_cls, ["desired_return_url"]
|
||||||
|
)
|
||||||
|
|
||||||
# store state in redis
|
# store state in redis
|
||||||
if not desired_return_url:
|
if not desired_return_url:
|
||||||
desired_return_url = f"{base_url}/admin/connectors/{source}?step=0"
|
desired_return_url = f"{base_url}/admin/connectors/{source}?step=0"
|
||||||
@ -78,12 +113,19 @@ def oauth_authorize(
|
|||||||
state = str(uuid.uuid4())
|
state = str(uuid.uuid4())
|
||||||
redis_client.set(
|
redis_client.set(
|
||||||
_OAUTH_STATE_KEY_FMT.format(state=state),
|
_OAUTH_STATE_KEY_FMT.format(state=state),
|
||||||
desired_return_url,
|
json.dumps(
|
||||||
|
{
|
||||||
|
_DESIRED_RETURN_URL_KEY: desired_return_url,
|
||||||
|
_ADDITIONAL_KWARGS_KEY: additional_kwargs,
|
||||||
|
}
|
||||||
|
),
|
||||||
ex=_OAUTH_STATE_EXPIRATION_SECONDS,
|
ex=_OAUTH_STATE_EXPIRATION_SECONDS,
|
||||||
)
|
)
|
||||||
|
|
||||||
return AuthorizeResponse(
|
return AuthorizeResponse(
|
||||||
redirect_url=connector_cls.oauth_authorization_url(base_url, state)
|
redirect_url=connector_cls.oauth_authorization_url(
|
||||||
|
base_url, state, additional_kwargs
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -110,15 +152,18 @@ def oauth_callback(
|
|||||||
|
|
||||||
# get state from redis
|
# get state from redis
|
||||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||||
original_url_bytes = cast(
|
oauth_state_bytes = cast(
|
||||||
bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state))
|
bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state))
|
||||||
)
|
)
|
||||||
if not original_url_bytes:
|
if not oauth_state_bytes:
|
||||||
raise HTTPException(status_code=400, detail="Invalid OAuth state")
|
raise HTTPException(status_code=400, detail="Invalid OAuth state")
|
||||||
original_url = original_url_bytes.decode("utf-8")
|
oauth_state = json.loads(oauth_state_bytes.decode("utf-8"))
|
||||||
|
|
||||||
|
desired_return_url = cast(str, oauth_state[_DESIRED_RETURN_URL_KEY])
|
||||||
|
additional_kwargs = cast(dict[str, str], oauth_state[_ADDITIONAL_KWARGS_KEY])
|
||||||
|
|
||||||
base_url = WEB_DOMAIN
|
base_url = WEB_DOMAIN
|
||||||
token_info = connector_cls.oauth_code_to_token(base_url, code)
|
token_info = connector_cls.oauth_code_to_token(base_url, code, additional_kwargs)
|
||||||
|
|
||||||
# Create a new credential with the token info
|
# Create a new credential with the token info
|
||||||
credential_data = CredentialBase(
|
credential_data = CredentialBase(
|
||||||
@ -136,8 +181,52 @@ def oauth_callback(
|
|||||||
|
|
||||||
return CallbackResponse(
|
return CallbackResponse(
|
||||||
redirect_url=(
|
redirect_url=(
|
||||||
f"{original_url}?credentialId={credential.id}"
|
f"{desired_return_url}?credentialId={credential.id}"
|
||||||
if "?" not in original_url
|
if "?" not in desired_return_url
|
||||||
else f"{original_url}&credentialId={credential.id}"
|
else f"{desired_return_url}&credentialId={credential.id}"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAdditionalKwargDescription(BaseModel):
|
||||||
|
name: str
|
||||||
|
display_name: str
|
||||||
|
description: str
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthDetails(BaseModel):
|
||||||
|
oauth_enabled: bool
|
||||||
|
additional_kwargs: list[OAuthAdditionalKwargDescription]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/details/{source}")
|
||||||
|
def oauth_details(
|
||||||
|
source: DocumentSource,
|
||||||
|
_: User = Depends(current_user),
|
||||||
|
) -> OAuthDetails:
|
||||||
|
oauth_connectors = _discover_oauth_connectors()
|
||||||
|
|
||||||
|
if source not in oauth_connectors:
|
||||||
|
return OAuthDetails(
|
||||||
|
oauth_enabled=False,
|
||||||
|
additional_kwargs=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
connector_cls = oauth_connectors[source]
|
||||||
|
|
||||||
|
additional_kwarg_descriptions = []
|
||||||
|
for key, value in connector_cls.AdditionalOauthKwargs.model_json_schema()[
|
||||||
|
"properties"
|
||||||
|
].items():
|
||||||
|
additional_kwarg_descriptions.append(
|
||||||
|
OAuthAdditionalKwargDescription(
|
||||||
|
name=key,
|
||||||
|
display_name=value.get("title", key),
|
||||||
|
description=value.get("description", ""),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return OAuthDetails(
|
||||||
|
oauth_enabled=True,
|
||||||
|
additional_kwargs=additional_kwarg_descriptions,
|
||||||
|
)
|
||||||
|
@ -196,7 +196,6 @@ services:
|
|||||||
# Egnyte OAuth Configs
|
# Egnyte OAuth Configs
|
||||||
- EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-}
|
- EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-}
|
||||||
- EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-}
|
- EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-}
|
||||||
- EGNYTE_BASE_DOMAIN=${EGNYTE_BASE_DOMAIN:-}
|
|
||||||
- EGNYTE_LOCALHOST_OVERRIDE=${EGNYTE_LOCALHOST_OVERRIDE:-}
|
- EGNYTE_LOCALHOST_OVERRIDE=${EGNYTE_LOCALHOST_OVERRIDE:-}
|
||||||
# Celery Configs (defaults are set in the supervisord.conf file.
|
# Celery Configs (defaults are set in the supervisord.conf file.
|
||||||
# prefer doing that to have one source of defaults)
|
# prefer doing that to have one source of defaults)
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||||
import useSWR, { mutate } from "swr";
|
import useSWR, { mutate } from "swr";
|
||||||
import { HealthCheckBanner } from "@/components/health/healthcheck";
|
|
||||||
|
|
||||||
import Title from "@/components/ui/title";
|
import Title from "@/components/ui/title";
|
||||||
import { AdminPageTitle } from "@/components/admin/Title";
|
import { AdminPageTitle } from "@/components/admin/Title";
|
||||||
@ -19,12 +18,12 @@ import AdvancedFormPage from "./pages/Advanced";
|
|||||||
import DynamicConnectionForm from "./pages/DynamicConnectorCreationForm";
|
import DynamicConnectionForm from "./pages/DynamicConnectorCreationForm";
|
||||||
import CreateCredential from "@/components/credentials/actions/CreateCredential";
|
import CreateCredential from "@/components/credentials/actions/CreateCredential";
|
||||||
import ModifyCredential from "@/components/credentials/actions/ModifyCredential";
|
import ModifyCredential from "@/components/credentials/actions/ModifyCredential";
|
||||||
|
import { ConfigurableSources, oauthSupportedSources } from "@/lib/types";
|
||||||
import {
|
import {
|
||||||
ConfigurableSources,
|
Credential,
|
||||||
oauthSupportedSources,
|
credentialTemplates,
|
||||||
ValidSources,
|
OAuthDetails,
|
||||||
} from "@/lib/types";
|
} from "@/lib/connectors/credentials";
|
||||||
import { Credential, credentialTemplates } from "@/lib/connectors/credentials";
|
|
||||||
import {
|
import {
|
||||||
ConnectionConfiguration,
|
ConnectionConfiguration,
|
||||||
connectorConfigs,
|
connectorConfigs,
|
||||||
@ -37,7 +36,6 @@ import {
|
|||||||
ConnectorBase,
|
ConnectorBase,
|
||||||
} from "@/lib/connectors/connectors";
|
} from "@/lib/connectors/connectors";
|
||||||
import { Modal } from "@/components/Modal";
|
import { Modal } from "@/components/Modal";
|
||||||
import GDriveMain from "./pages/gdrive/GoogleDrivePage";
|
|
||||||
import { GmailMain } from "./pages/gmail/GmailPage";
|
import { GmailMain } from "./pages/gmail/GmailPage";
|
||||||
import {
|
import {
|
||||||
useGmailCredentials,
|
useGmailCredentials,
|
||||||
@ -54,7 +52,12 @@ import {
|
|||||||
NEXT_PUBLIC_TEST_ENV,
|
NEXT_PUBLIC_TEST_ENV,
|
||||||
} from "@/lib/constants";
|
} from "@/lib/constants";
|
||||||
import TemporaryLoadingModal from "@/components/TemporaryLoadingModal";
|
import TemporaryLoadingModal from "@/components/TemporaryLoadingModal";
|
||||||
import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth";
|
import {
|
||||||
|
getConnectorOauthRedirectUrl,
|
||||||
|
useOAuthDetails,
|
||||||
|
} from "@/lib/connectors/oauth";
|
||||||
|
import { CreateStdOAuthCredential } from "@/components/credentials/actions/CreateStdOAuthCredential";
|
||||||
|
import { Spinner } from "@/components/Spinner";
|
||||||
export interface AdvancedConfig {
|
export interface AdvancedConfig {
|
||||||
refreshFreq: number;
|
refreshFreq: number;
|
||||||
pruneFreq: number;
|
pruneFreq: number;
|
||||||
@ -144,7 +147,8 @@ export default function AddConnector({
|
|||||||
// State for managing credentials and files
|
// State for managing credentials and files
|
||||||
const [currentCredential, setCurrentCredential] =
|
const [currentCredential, setCurrentCredential] =
|
||||||
useState<Credential<any> | null>(null);
|
useState<Credential<any> | null>(null);
|
||||||
const [createConnectorToggle, setCreateConnectorToggle] = useState(false);
|
const [createCredentialFormToggle, setCreateCredentialFormToggle] =
|
||||||
|
useState(false);
|
||||||
|
|
||||||
// Fetch credentials data
|
// Fetch credentials data
|
||||||
const { data: credentials } = useSWR<Credential<any>[]>(
|
const { data: credentials } = useSWR<Credential<any>[]>(
|
||||||
@ -159,6 +163,9 @@ export default function AddConnector({
|
|||||||
{ refreshInterval: 5000 }
|
{ refreshInterval: 5000 }
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const { data: oauthDetails, isLoading: oauthDetailsLoading } =
|
||||||
|
useOAuthDetails(connector);
|
||||||
|
|
||||||
// Get credential template and configuration
|
// Get credential template and configuration
|
||||||
const credentialTemplate = credentialTemplates[connector];
|
const credentialTemplate = credentialTemplates[connector];
|
||||||
const configuration: ConnectionConfiguration = connectorConfigs[connector];
|
const configuration: ConnectionConfiguration = connectorConfigs[connector];
|
||||||
@ -450,19 +457,33 @@ export default function AddConnector({
|
|||||||
onDeleteCredential={onDeleteCredential}
|
onDeleteCredential={onDeleteCredential}
|
||||||
onSwitch={onSwap}
|
onSwitch={onSwap}
|
||||||
/>
|
/>
|
||||||
{!createConnectorToggle && (
|
{!createCredentialFormToggle && (
|
||||||
<div className="mt-6 flex space-x-4">
|
<div className="mt-6 flex space-x-4">
|
||||||
{/* Button to pop up a form to manually enter credentials */}
|
{/* Button to pop up a form to manually enter credentials */}
|
||||||
<button
|
<button
|
||||||
className="mt-6 text-sm bg-background-900 px-2 py-1.5 flex text-text-200 flex-none rounded mr-4"
|
className="mt-6 text-sm bg-background-900 px-2 py-1.5 flex text-text-200 flex-none rounded mr-4"
|
||||||
onClick={async () => {
|
onClick={async () => {
|
||||||
const redirectUrl =
|
if (oauthDetails && oauthDetails.oauth_enabled) {
|
||||||
await getConnectorOauthRedirectUrl(connector);
|
if (oauthDetails.additional_kwargs.length > 0) {
|
||||||
// if redirect is supported, just use it
|
setCreateCredentialFormToggle(true);
|
||||||
if (redirectUrl) {
|
} else {
|
||||||
window.location.href = redirectUrl;
|
const redirectUrl =
|
||||||
|
await getConnectorOauthRedirectUrl(
|
||||||
|
connector,
|
||||||
|
{}
|
||||||
|
);
|
||||||
|
// if redirect is supported, just use it
|
||||||
|
if (redirectUrl) {
|
||||||
|
window.location.href = redirectUrl;
|
||||||
|
} else {
|
||||||
|
setCreateCredentialFormToggle(
|
||||||
|
(createConnectorToggle) =>
|
||||||
|
!createConnectorToggle
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
setCreateConnectorToggle(
|
setCreateCredentialFormToggle(
|
||||||
(createConnectorToggle) =>
|
(createConnectorToggle) =>
|
||||||
!createConnectorToggle
|
!createConnectorToggle
|
||||||
);
|
);
|
||||||
@ -491,25 +512,42 @@ export default function AddConnector({
|
|||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{createConnectorToggle && (
|
{createCredentialFormToggle && (
|
||||||
<Modal
|
<Modal
|
||||||
className="max-w-3xl rounded-lg"
|
className="max-w-3xl rounded-lg"
|
||||||
onOutsideClick={() => setCreateConnectorToggle(false)}
|
onOutsideClick={() =>
|
||||||
|
setCreateCredentialFormToggle(false)
|
||||||
|
}
|
||||||
>
|
>
|
||||||
<>
|
{oauthDetailsLoading ? (
|
||||||
<Title className="mb-2 text-lg">
|
<Spinner />
|
||||||
Create a {getSourceDisplayName(connector)}{" "}
|
) : (
|
||||||
credential
|
<>
|
||||||
</Title>
|
<Title className="mb-2 text-lg">
|
||||||
<CreateCredential
|
Create a {getSourceDisplayName(connector)}{" "}
|
||||||
close
|
credential
|
||||||
refresh={refresh}
|
</Title>
|
||||||
sourceType={connector}
|
{oauthDetails && oauthDetails.oauth_enabled ? (
|
||||||
setPopup={setPopup}
|
<CreateStdOAuthCredential
|
||||||
onSwitch={onSwap}
|
sourceType={connector}
|
||||||
onClose={() => setCreateConnectorToggle(false)}
|
additionalFields={
|
||||||
/>
|
oauthDetails.additional_kwargs
|
||||||
</>
|
}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<CreateCredential
|
||||||
|
close
|
||||||
|
refresh={refresh}
|
||||||
|
sourceType={connector}
|
||||||
|
setPopup={setPopup}
|
||||||
|
onSwitch={onSwap}
|
||||||
|
onClose={() =>
|
||||||
|
setCreateCredentialFormToggle(false)
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
</Modal>
|
</Modal>
|
||||||
)}
|
)}
|
||||||
</>
|
</>
|
||||||
|
@ -28,7 +28,12 @@ import {
|
|||||||
ConfluenceCredentialJson,
|
ConfluenceCredentialJson,
|
||||||
Credential,
|
Credential,
|
||||||
} from "@/lib/connectors/credentials";
|
} from "@/lib/connectors/credentials";
|
||||||
import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth";
|
import {
|
||||||
|
getConnectorOauthRedirectUrl,
|
||||||
|
useOAuthDetails,
|
||||||
|
} from "@/lib/connectors/oauth";
|
||||||
|
import { Spinner } from "@/components/Spinner";
|
||||||
|
import { CreateStdOAuthCredential } from "@/components/credentials/actions/CreateStdOAuthCredential";
|
||||||
|
|
||||||
export default function CredentialSection({
|
export default function CredentialSection({
|
||||||
ccPair,
|
ccPair,
|
||||||
@ -39,16 +44,6 @@ export default function CredentialSection({
|
|||||||
sourceType: ValidSources;
|
sourceType: ValidSources;
|
||||||
refresh: () => void;
|
refresh: () => void;
|
||||||
}) {
|
}) {
|
||||||
const makeShowCreateCredential = async () => {
|
|
||||||
const redirectUrl = await getConnectorOauthRedirectUrl(sourceType);
|
|
||||||
if (redirectUrl) {
|
|
||||||
window.location.href = redirectUrl;
|
|
||||||
} else {
|
|
||||||
setShowModifyCredential(false);
|
|
||||||
setShowCreateCredential(true);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const { data: credentials } = useSWR<Credential<ConfluenceCredentialJson>[]>(
|
const { data: credentials } = useSWR<Credential<ConfluenceCredentialJson>[]>(
|
||||||
buildSimilarCredentialInfoURL(sourceType),
|
buildSimilarCredentialInfoURL(sourceType),
|
||||||
errorHandlingFetcher,
|
errorHandlingFetcher,
|
||||||
@ -59,6 +54,28 @@ export default function CredentialSection({
|
|||||||
errorHandlingFetcher,
|
errorHandlingFetcher,
|
||||||
{ refreshInterval: 5000 }
|
{ refreshInterval: 5000 }
|
||||||
);
|
);
|
||||||
|
const { data: oauthDetails, isLoading: oauthDetailsLoading } =
|
||||||
|
useOAuthDetails(sourceType);
|
||||||
|
|
||||||
|
const makeShowCreateCredential = async () => {
|
||||||
|
if (oauthDetailsLoading || !oauthDetails) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (oauthDetails.oauth_enabled) {
|
||||||
|
if (oauthDetails.additional_kwargs.length > 0) {
|
||||||
|
setShowCreateCredential(true);
|
||||||
|
} else {
|
||||||
|
const redirectUrl = await getConnectorOauthRedirectUrl(sourceType, {});
|
||||||
|
if (redirectUrl) {
|
||||||
|
window.location.href = redirectUrl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
setShowModifyCredential(false);
|
||||||
|
setShowCreateCredential(true);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const onSwap = async (
|
const onSwap = async (
|
||||||
selectedCredential: Credential<any>,
|
selectedCredential: Credential<any>,
|
||||||
@ -193,13 +210,26 @@ export default function CredentialSection({
|
|||||||
className="max-w-3xl rounded-lg"
|
className="max-w-3xl rounded-lg"
|
||||||
title={`Create ${getSourceDisplayName(sourceType)} Credential`}
|
title={`Create ${getSourceDisplayName(sourceType)} Credential`}
|
||||||
>
|
>
|
||||||
<CreateCredential
|
{oauthDetailsLoading ? (
|
||||||
sourceType={sourceType}
|
<Spinner />
|
||||||
swapConnector={ccPair.connector}
|
) : (
|
||||||
setPopup={setPopup}
|
<>
|
||||||
onSwap={onSwap}
|
{oauthDetails && oauthDetails.oauth_enabled ? (
|
||||||
onClose={closeCreateCredential}
|
<CreateStdOAuthCredential
|
||||||
/>
|
sourceType={sourceType}
|
||||||
|
additionalFields={oauthDetails.additional_kwargs}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<CreateCredential
|
||||||
|
sourceType={sourceType}
|
||||||
|
swapConnector={ccPair.connector}
|
||||||
|
setPopup={setPopup}
|
||||||
|
onSwap={onSwap}
|
||||||
|
onClose={closeCreateCredential}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
</Modal>
|
</Modal>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
@ -0,0 +1,88 @@
|
|||||||
|
import * as Yup from "yup";
|
||||||
|
|
||||||
|
import React from "react";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import { ValidSources } from "@/lib/types";
|
||||||
|
import { TextFormField } from "@/components/admin/connectors/Field";
|
||||||
|
import { Form, Formik, FormikHelpers } from "formik";
|
||||||
|
import CardSection from "@/components/admin/CardSection";
|
||||||
|
import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth";
|
||||||
|
import { OAuthAdditionalKwargDescription } from "@/lib/connectors/credentials";
|
||||||
|
|
||||||
|
type formType = {
|
||||||
|
[key: string]: any; // For additional credential fields
|
||||||
|
};
|
||||||
|
|
||||||
|
export function CreateStdOAuthCredential({
|
||||||
|
sourceType,
|
||||||
|
additionalFields,
|
||||||
|
}: {
|
||||||
|
// Source information
|
||||||
|
sourceType: ValidSources;
|
||||||
|
|
||||||
|
additionalFields: OAuthAdditionalKwargDescription[];
|
||||||
|
}) {
|
||||||
|
const handleSubmit = async (
|
||||||
|
values: formType,
|
||||||
|
formikHelpers: FormikHelpers<formType>
|
||||||
|
) => {
|
||||||
|
const { setSubmitting, validateForm } = formikHelpers;
|
||||||
|
|
||||||
|
const errors = await validateForm(values);
|
||||||
|
if (Object.keys(errors).length > 0) {
|
||||||
|
formikHelpers.setErrors(errors);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
setSubmitting(true);
|
||||||
|
formikHelpers.setSubmitting(true);
|
||||||
|
|
||||||
|
const redirectUrl = await getConnectorOauthRedirectUrl(sourceType, values);
|
||||||
|
|
||||||
|
if (!redirectUrl) {
|
||||||
|
throw new Error("No redirect URL found for OAuth connector");
|
||||||
|
}
|
||||||
|
|
||||||
|
window.location.href = redirectUrl;
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Formik
|
||||||
|
initialValues={
|
||||||
|
{
|
||||||
|
...Object.fromEntries(additionalFields.map((field) => [field, ""])),
|
||||||
|
} as formType
|
||||||
|
}
|
||||||
|
validationSchema={Yup.object().shape({
|
||||||
|
...Object.fromEntries(
|
||||||
|
additionalFields.map((field) => [field.name, Yup.string().required()])
|
||||||
|
),
|
||||||
|
})}
|
||||||
|
onSubmit={(values, formikHelpers) => {
|
||||||
|
handleSubmit(values, formikHelpers);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{() => (
|
||||||
|
<Form className="w-full flex items-stretch">
|
||||||
|
<CardSection className="w-full !border-0 mt-4 flex flex-col gap-y-6">
|
||||||
|
{additionalFields.map((field) => (
|
||||||
|
<TextFormField
|
||||||
|
key={field.name}
|
||||||
|
name={field.name}
|
||||||
|
label={field.display_name}
|
||||||
|
subtext={field.description}
|
||||||
|
type="text"
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
|
||||||
|
<div className="flex w-full">
|
||||||
|
<Button type="submit" className="flex text-sm">
|
||||||
|
Create
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</CardSection>
|
||||||
|
</Form>
|
||||||
|
)}
|
||||||
|
</Formik>
|
||||||
|
);
|
||||||
|
}
|
@ -1,5 +1,16 @@
|
|||||||
import { ValidSources } from "../types";
|
import { ValidSources } from "../types";
|
||||||
|
|
||||||
|
export interface OAuthAdditionalKwargDescription {
|
||||||
|
name: string;
|
||||||
|
display_name: string;
|
||||||
|
description: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface OAuthDetails {
|
||||||
|
oauth_enabled: boolean;
|
||||||
|
additional_kwargs: OAuthAdditionalKwargDescription[];
|
||||||
|
}
|
||||||
|
|
||||||
export interface CredentialBase<T> {
|
export interface CredentialBase<T> {
|
||||||
credential_json: T;
|
credential_json: T;
|
||||||
admin_public: boolean;
|
admin_public: boolean;
|
||||||
|
@ -1,12 +1,18 @@
|
|||||||
|
import useSWR from "swr";
|
||||||
import { ValidSources } from "../types";
|
import { ValidSources } from "../types";
|
||||||
|
import { OAuthDetails } from "./credentials";
|
||||||
|
import { errorHandlingFetcher } from "../fetcher";
|
||||||
|
|
||||||
export async function getConnectorOauthRedirectUrl(
|
export async function getConnectorOauthRedirectUrl(
|
||||||
connector: ValidSources
|
connector: ValidSources,
|
||||||
|
additional_kwargs: Record<string, string>
|
||||||
): Promise<string | null> {
|
): Promise<string | null> {
|
||||||
|
const queryParams = new URLSearchParams({
|
||||||
|
desired_return_url: window.location.href,
|
||||||
|
...additional_kwargs,
|
||||||
|
});
|
||||||
const response = await fetch(
|
const response = await fetch(
|
||||||
`/api/connector/oauth/authorize/${connector}?desired_return_url=${encodeURIComponent(
|
`/api/connector/oauth/authorize/${connector}?${queryParams.toString()}`
|
||||||
window.location.href
|
|
||||||
)}`
|
|
||||||
);
|
);
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
@ -17,3 +23,10 @@ export async function getConnectorOauthRedirectUrl(
|
|||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
return data.redirect_url as string;
|
return data.redirect_url as string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function useOAuthDetails(sourceType: ValidSources) {
|
||||||
|
return useSWR<OAuthDetails>(
|
||||||
|
`/api/connector/oauth/details/${sourceType}`,
|
||||||
|
errorHandlingFetcher
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user