Add support for OAuth connectors that require user input (#3571)

* Add support for OAuth connectors that require user input

* Cleanup

* Fix linear

* Small re-naming

* Remove console.log
This commit is contained in:
Chris Weaver
2024-12-31 18:03:33 -08:00
committed by GitHub
parent ccd3983802
commit d64464ca7c
11 changed files with 389 additions and 81 deletions

View File

@ -374,7 +374,6 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE")
# Egnyte specific configs
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")

View File

@ -7,7 +7,8 @@ from typing import Any
from typing import IO
from urllib.parse import quote
from onyx.configs.app_configs import EGNYTE_BASE_DOMAIN
from pydantic import Field
from onyx.configs.app_configs import EGNYTE_CLIENT_ID
from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET
from onyx.configs.app_configs import INDEX_BATCH_SIZE
@ -124,6 +125,15 @@ def _process_egnyte_file(
class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
class AdditionalOauthKwargs(OAuthConnector.AdditionalOauthKwargs):
egnyte_domain: str = Field(
title="Egnyte Domain",
description=(
"The domain for the Egnyte instance "
"(e.g. 'company' for company.egnyte.com)"
),
)
def __init__(
self,
folder_path: str | None = None,
@ -139,15 +149,20 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
return DocumentSource.EGNYTE
@classmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
def oauth_authorization_url(
cls,
base_domain: str,
state: str,
additional_kwargs: dict[str, str],
) -> str:
if not EGNYTE_CLIENT_ID:
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
if not EGNYTE_BASE_DOMAIN:
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
callback_uri = get_oauth_callback_uri(base_domain, "egnyte")
return (
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
f"?client_id={EGNYTE_CLIENT_ID}"
f"&redirect_uri={callback_uri}"
f"&scope=Egnyte.filesystem"
@ -156,17 +171,23 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
)
@classmethod
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
def oauth_code_to_token(
cls,
base_domain: str,
code: str,
additional_kwargs: dict[str, str],
) -> dict[str, Any]:
if not EGNYTE_CLIENT_ID:
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
if not EGNYTE_CLIENT_SECRET:
raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
if not EGNYTE_BASE_DOMAIN:
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
# Exchange code for token
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
url = f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
redirect_uri = get_oauth_callback_uri(base_domain, "egnyte")
data = {
"client_id": EGNYTE_CLIENT_ID,
"client_secret": EGNYTE_CLIENT_SECRET,
@ -191,7 +212,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
token_data = response.json()
return {
"domain": EGNYTE_BASE_DOMAIN,
"domain": oauth_kwargs.egnyte_domain,
"access_token": token_data["access_token"],
}
@ -215,7 +236,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
"list_content": True,
}
url_encoded_path = quote(path or "", safe="")
url_encoded_path = quote(path or "")
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}"
response = request_with_retries(
method="GET", url=url, headers=headers, params=params
@ -271,7 +292,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
headers = {
"Authorization": f"Bearer {self.access_token}",
}
url_encoded_path = quote(file["path"], safe="")
url_encoded_path = quote(file["path"])
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}"
response = request_with_retries(
method="GET",

View File

@ -2,6 +2,8 @@ import abc
from collections.abc import Iterator
from typing import Any
from pydantic import BaseModel
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
@ -66,6 +68,10 @@ class SlimConnector(BaseConnector):
class OAuthConnector(BaseConnector):
class AdditionalOauthKwargs(BaseModel):
# if overridden, all fields should be str type
pass
@classmethod
@abc.abstractmethod
def oauth_id(cls) -> DocumentSource:
@ -73,12 +79,22 @@ class OAuthConnector(BaseConnector):
@classmethod
@abc.abstractmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
def oauth_authorization_url(
cls,
base_domain: str,
state: str,
additional_kwargs: dict[str, str],
) -> str:
raise NotImplementedError
@classmethod
@abc.abstractmethod
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
def oauth_code_to_token(
cls,
base_domain: str,
code: str,
additional_kwargs: dict[str, str],
) -> dict[str, Any]:
raise NotImplementedError

View File

@ -77,7 +77,9 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
return DocumentSource.LINEAR
@classmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
def oauth_authorization_url(
cls, base_domain: str, state: str, additional_kwargs: dict[str, str]
) -> str:
if not LINEAR_CLIENT_ID:
raise ValueError("LINEAR_CLIENT_ID environment variable must be set")
@ -92,7 +94,9 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
)
@classmethod
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
def oauth_code_to_token(
cls, base_domain: str, code: str, additional_kwargs: dict[str, str]
) -> dict[str, Any]:
data = {
"code": code,
"redirect_uri": get_oauth_callback_uri(

View File

@ -1,3 +1,4 @@
import json
import uuid
from typing import Annotated
from typing import cast
@ -6,7 +7,9 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
@ -28,6 +31,8 @@ router = APIRouter(prefix="/connector/oauth")
_OAUTH_STATE_KEY_FMT = "oauth_state:{state}"
_OAUTH_STATE_EXPIRATION_SECONDS = 10 * 60 # 10 minutes
_DESIRED_RETURN_URL_KEY = "desired_return_url"
_ADDITIONAL_KWARGS_KEY = "additional_kwargs"
# Cache for OAuth connectors, populated at module load time
_OAUTH_CONNECTORS: dict[DocumentSource, type[OAuthConnector]] = {}
@ -51,12 +56,36 @@ def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]:
_discover_oauth_connectors()
def _get_additional_kwargs(
request: Request, connector_cls: type[OAuthConnector], args_to_ignore: list[str]
) -> dict[str, str]:
# get additional kwargs from request
# e.g. anything except for desired_return_url
additional_kwargs_dict = {
k: v for k, v in request.query_params.items() if k not in args_to_ignore
}
try:
# validate
connector_cls.AdditionalOauthKwargs(**additional_kwargs_dict)
except ValidationError:
raise HTTPException(
status_code=400,
detail=(
f"Invalid additional kwargs. Got {additional_kwargs_dict}, expected "
f"{connector_cls.AdditionalOauthKwargs.model_json_schema()}"
),
)
return additional_kwargs_dict
class AuthorizeResponse(BaseModel):
redirect_url: str
@router.get("/authorize/{source}")
def oauth_authorize(
request: Request,
source: DocumentSource,
desired_return_url: Annotated[str | None, Query()] = None,
_: User = Depends(current_user),
@ -71,6 +100,12 @@ def oauth_authorize(
connector_cls = oauth_connectors[source]
base_url = WEB_DOMAIN
# get additional kwargs from request
# e.g. anything except for desired_return_url
additional_kwargs = _get_additional_kwargs(
request, connector_cls, ["desired_return_url"]
)
# store state in redis
if not desired_return_url:
desired_return_url = f"{base_url}/admin/connectors/{source}?step=0"
@ -78,12 +113,19 @@ def oauth_authorize(
state = str(uuid.uuid4())
redis_client.set(
_OAUTH_STATE_KEY_FMT.format(state=state),
desired_return_url,
json.dumps(
{
_DESIRED_RETURN_URL_KEY: desired_return_url,
_ADDITIONAL_KWARGS_KEY: additional_kwargs,
}
),
ex=_OAUTH_STATE_EXPIRATION_SECONDS,
)
return AuthorizeResponse(
redirect_url=connector_cls.oauth_authorization_url(base_url, state)
redirect_url=connector_cls.oauth_authorization_url(
base_url, state, additional_kwargs
)
)
@ -110,15 +152,18 @@ def oauth_callback(
# get state from redis
redis_client = get_redis_client(tenant_id=tenant_id)
original_url_bytes = cast(
oauth_state_bytes = cast(
bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state))
)
if not original_url_bytes:
if not oauth_state_bytes:
raise HTTPException(status_code=400, detail="Invalid OAuth state")
original_url = original_url_bytes.decode("utf-8")
oauth_state = json.loads(oauth_state_bytes.decode("utf-8"))
desired_return_url = cast(str, oauth_state[_DESIRED_RETURN_URL_KEY])
additional_kwargs = cast(dict[str, str], oauth_state[_ADDITIONAL_KWARGS_KEY])
base_url = WEB_DOMAIN
token_info = connector_cls.oauth_code_to_token(base_url, code)
token_info = connector_cls.oauth_code_to_token(base_url, code, additional_kwargs)
# Create a new credential with the token info
credential_data = CredentialBase(
@ -136,8 +181,52 @@ def oauth_callback(
return CallbackResponse(
redirect_url=(
f"{original_url}?credentialId={credential.id}"
if "?" not in original_url
else f"{original_url}&credentialId={credential.id}"
f"{desired_return_url}?credentialId={credential.id}"
if "?" not in desired_return_url
else f"{desired_return_url}&credentialId={credential.id}"
)
)
class OAuthAdditionalKwargDescription(BaseModel):
name: str
display_name: str
description: str
class OAuthDetails(BaseModel):
oauth_enabled: bool
additional_kwargs: list[OAuthAdditionalKwargDescription]
@router.get("/details/{source}")
def oauth_details(
source: DocumentSource,
_: User = Depends(current_user),
) -> OAuthDetails:
oauth_connectors = _discover_oauth_connectors()
if source not in oauth_connectors:
return OAuthDetails(
oauth_enabled=False,
additional_kwargs=[],
)
connector_cls = oauth_connectors[source]
additional_kwarg_descriptions = []
for key, value in connector_cls.AdditionalOauthKwargs.model_json_schema()[
"properties"
].items():
additional_kwarg_descriptions.append(
OAuthAdditionalKwargDescription(
name=key,
display_name=value.get("title", key),
description=value.get("description", ""),
)
)
return OAuthDetails(
oauth_enabled=True,
additional_kwargs=additional_kwarg_descriptions,
)

View File

@ -196,7 +196,6 @@ services:
# Egnyte OAuth Configs
- EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-}
- EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-}
- EGNYTE_BASE_DOMAIN=${EGNYTE_BASE_DOMAIN:-}
- EGNYTE_LOCALHOST_OVERRIDE=${EGNYTE_LOCALHOST_OVERRIDE:-}
# Celery Configs (defaults are set in the supervisord.conf file.
# prefer doing that to have one source of defaults)

View File

@ -2,7 +2,6 @@
import { errorHandlingFetcher } from "@/lib/fetcher";
import useSWR, { mutate } from "swr";
import { HealthCheckBanner } from "@/components/health/healthcheck";
import Title from "@/components/ui/title";
import { AdminPageTitle } from "@/components/admin/Title";
@ -19,12 +18,12 @@ import AdvancedFormPage from "./pages/Advanced";
import DynamicConnectionForm from "./pages/DynamicConnectorCreationForm";
import CreateCredential from "@/components/credentials/actions/CreateCredential";
import ModifyCredential from "@/components/credentials/actions/ModifyCredential";
import { ConfigurableSources, oauthSupportedSources } from "@/lib/types";
import {
ConfigurableSources,
oauthSupportedSources,
ValidSources,
} from "@/lib/types";
import { Credential, credentialTemplates } from "@/lib/connectors/credentials";
Credential,
credentialTemplates,
OAuthDetails,
} from "@/lib/connectors/credentials";
import {
ConnectionConfiguration,
connectorConfigs,
@ -37,7 +36,6 @@ import {
ConnectorBase,
} from "@/lib/connectors/connectors";
import { Modal } from "@/components/Modal";
import GDriveMain from "./pages/gdrive/GoogleDrivePage";
import { GmailMain } from "./pages/gmail/GmailPage";
import {
useGmailCredentials,
@ -54,7 +52,12 @@ import {
NEXT_PUBLIC_TEST_ENV,
} from "@/lib/constants";
import TemporaryLoadingModal from "@/components/TemporaryLoadingModal";
import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth";
import {
getConnectorOauthRedirectUrl,
useOAuthDetails,
} from "@/lib/connectors/oauth";
import { CreateStdOAuthCredential } from "@/components/credentials/actions/CreateStdOAuthCredential";
import { Spinner } from "@/components/Spinner";
export interface AdvancedConfig {
refreshFreq: number;
pruneFreq: number;
@ -144,7 +147,8 @@ export default function AddConnector({
// State for managing credentials and files
const [currentCredential, setCurrentCredential] =
useState<Credential<any> | null>(null);
const [createConnectorToggle, setCreateConnectorToggle] = useState(false);
const [createCredentialFormToggle, setCreateCredentialFormToggle] =
useState(false);
// Fetch credentials data
const { data: credentials } = useSWR<Credential<any>[]>(
@ -159,6 +163,9 @@ export default function AddConnector({
{ refreshInterval: 5000 }
);
const { data: oauthDetails, isLoading: oauthDetailsLoading } =
useOAuthDetails(connector);
// Get credential template and configuration
const credentialTemplate = credentialTemplates[connector];
const configuration: ConnectionConfiguration = connectorConfigs[connector];
@ -450,19 +457,33 @@ export default function AddConnector({
onDeleteCredential={onDeleteCredential}
onSwitch={onSwap}
/>
{!createConnectorToggle && (
{!createCredentialFormToggle && (
<div className="mt-6 flex space-x-4">
{/* Button to pop up a form to manually enter credentials */}
<button
className="mt-6 text-sm bg-background-900 px-2 py-1.5 flex text-text-200 flex-none rounded mr-4"
onClick={async () => {
if (oauthDetails && oauthDetails.oauth_enabled) {
if (oauthDetails.additional_kwargs.length > 0) {
setCreateCredentialFormToggle(true);
} else {
const redirectUrl =
await getConnectorOauthRedirectUrl(connector);
await getConnectorOauthRedirectUrl(
connector,
{}
);
// if redirect is supported, just use it
if (redirectUrl) {
window.location.href = redirectUrl;
} else {
setCreateConnectorToggle(
setCreateCredentialFormToggle(
(createConnectorToggle) =>
!createConnectorToggle
);
}
}
} else {
setCreateCredentialFormToggle(
(createConnectorToggle) =>
!createConnectorToggle
);
@ -491,25 +512,42 @@ export default function AddConnector({
</div>
)}
{createConnectorToggle && (
{createCredentialFormToggle && (
<Modal
className="max-w-3xl rounded-lg"
onOutsideClick={() => setCreateConnectorToggle(false)}
onOutsideClick={() =>
setCreateCredentialFormToggle(false)
}
>
{oauthDetailsLoading ? (
<Spinner />
) : (
<>
<Title className="mb-2 text-lg">
Create a {getSourceDisplayName(connector)}{" "}
credential
</Title>
{oauthDetails && oauthDetails.oauth_enabled ? (
<CreateStdOAuthCredential
sourceType={connector}
additionalFields={
oauthDetails.additional_kwargs
}
/>
) : (
<CreateCredential
close
refresh={refresh}
sourceType={connector}
setPopup={setPopup}
onSwitch={onSwap}
onClose={() => setCreateConnectorToggle(false)}
onClose={() =>
setCreateCredentialFormToggle(false)
}
/>
)}
</>
)}
</Modal>
)}
</>

View File

@ -28,7 +28,12 @@ import {
ConfluenceCredentialJson,
Credential,
} from "@/lib/connectors/credentials";
import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth";
import {
getConnectorOauthRedirectUrl,
useOAuthDetails,
} from "@/lib/connectors/oauth";
import { Spinner } from "@/components/Spinner";
import { CreateStdOAuthCredential } from "@/components/credentials/actions/CreateStdOAuthCredential";
export default function CredentialSection({
ccPair,
@ -39,16 +44,6 @@ export default function CredentialSection({
sourceType: ValidSources;
refresh: () => void;
}) {
const makeShowCreateCredential = async () => {
const redirectUrl = await getConnectorOauthRedirectUrl(sourceType);
if (redirectUrl) {
window.location.href = redirectUrl;
} else {
setShowModifyCredential(false);
setShowCreateCredential(true);
}
};
const { data: credentials } = useSWR<Credential<ConfluenceCredentialJson>[]>(
buildSimilarCredentialInfoURL(sourceType),
errorHandlingFetcher,
@ -59,6 +54,28 @@ export default function CredentialSection({
errorHandlingFetcher,
{ refreshInterval: 5000 }
);
const { data: oauthDetails, isLoading: oauthDetailsLoading } =
useOAuthDetails(sourceType);
const makeShowCreateCredential = async () => {
if (oauthDetailsLoading || !oauthDetails) {
return;
}
if (oauthDetails.oauth_enabled) {
if (oauthDetails.additional_kwargs.length > 0) {
setShowCreateCredential(true);
} else {
const redirectUrl = await getConnectorOauthRedirectUrl(sourceType, {});
if (redirectUrl) {
window.location.href = redirectUrl;
}
}
} else {
setShowModifyCredential(false);
setShowCreateCredential(true);
}
};
const onSwap = async (
selectedCredential: Credential<any>,
@ -193,6 +210,16 @@ export default function CredentialSection({
className="max-w-3xl rounded-lg"
title={`Create ${getSourceDisplayName(sourceType)} Credential`}
>
{oauthDetailsLoading ? (
<Spinner />
) : (
<>
{oauthDetails && oauthDetails.oauth_enabled ? (
<CreateStdOAuthCredential
sourceType={sourceType}
additionalFields={oauthDetails.additional_kwargs}
/>
) : (
<CreateCredential
sourceType={sourceType}
swapConnector={ccPair.connector}
@ -200,6 +227,9 @@ export default function CredentialSection({
onSwap={onSwap}
onClose={closeCreateCredential}
/>
)}
</>
)}
</Modal>
)}
</div>

View File

@ -0,0 +1,88 @@
import * as Yup from "yup";
import React from "react";
import { Button } from "@/components/ui/button";
import { ValidSources } from "@/lib/types";
import { TextFormField } from "@/components/admin/connectors/Field";
import { Form, Formik, FormikHelpers } from "formik";
import CardSection from "@/components/admin/CardSection";
import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth";
import { OAuthAdditionalKwargDescription } from "@/lib/connectors/credentials";
type formType = {
[key: string]: any; // For additional credential fields
};
export function CreateStdOAuthCredential({
sourceType,
additionalFields,
}: {
// Source information
sourceType: ValidSources;
additionalFields: OAuthAdditionalKwargDescription[];
}) {
const handleSubmit = async (
values: formType,
formikHelpers: FormikHelpers<formType>
) => {
const { setSubmitting, validateForm } = formikHelpers;
const errors = await validateForm(values);
if (Object.keys(errors).length > 0) {
formikHelpers.setErrors(errors);
return;
}
setSubmitting(true);
formikHelpers.setSubmitting(true);
const redirectUrl = await getConnectorOauthRedirectUrl(sourceType, values);
if (!redirectUrl) {
throw new Error("No redirect URL found for OAuth connector");
}
window.location.href = redirectUrl;
};
return (
<Formik
initialValues={
{
...Object.fromEntries(additionalFields.map((field) => [field, ""])),
} as formType
}
validationSchema={Yup.object().shape({
...Object.fromEntries(
additionalFields.map((field) => [field.name, Yup.string().required()])
),
})}
onSubmit={(values, formikHelpers) => {
handleSubmit(values, formikHelpers);
}}
>
{() => (
<Form className="w-full flex items-stretch">
<CardSection className="w-full !border-0 mt-4 flex flex-col gap-y-6">
{additionalFields.map((field) => (
<TextFormField
key={field.name}
name={field.name}
label={field.display_name}
subtext={field.description}
type="text"
/>
))}
<div className="flex w-full">
<Button type="submit" className="flex text-sm">
Create
</Button>
</div>
</CardSection>
</Form>
)}
</Formik>
);
}

View File

@ -1,5 +1,16 @@
import { ValidSources } from "../types";
export interface OAuthAdditionalKwargDescription {
name: string;
display_name: string;
description: string;
}
export interface OAuthDetails {
oauth_enabled: boolean;
additional_kwargs: OAuthAdditionalKwargDescription[];
}
export interface CredentialBase<T> {
credential_json: T;
admin_public: boolean;

View File

@ -1,12 +1,18 @@
import useSWR from "swr";
import { ValidSources } from "../types";
import { OAuthDetails } from "./credentials";
import { errorHandlingFetcher } from "../fetcher";
export async function getConnectorOauthRedirectUrl(
connector: ValidSources
connector: ValidSources,
additional_kwargs: Record<string, string>
): Promise<string | null> {
const queryParams = new URLSearchParams({
desired_return_url: window.location.href,
...additional_kwargs,
});
const response = await fetch(
`/api/connector/oauth/authorize/${connector}?desired_return_url=${encodeURIComponent(
window.location.href
)}`
`/api/connector/oauth/authorize/${connector}?${queryParams.toString()}`
);
if (!response.ok) {
@ -17,3 +23,10 @@ export async function getConnectorOauthRedirectUrl(
const data = await response.json();
return data.redirect_url as string;
}
export function useOAuthDetails(sourceType: ValidSources) {
return useSWR<OAuthDetails>(
`/api/connector/oauth/details/${sourceType}`,
errorHandlingFetcher
);
}