diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index 1c8be75fc8..21b74153ed 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -374,7 +374,6 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = ( CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE") # Egnyte specific configs -EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN") EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID") EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET") diff --git a/backend/onyx/connectors/egnyte/connector.py b/backend/onyx/connectors/egnyte/connector.py index 0fa82cd55e..979f5a83ea 100644 --- a/backend/onyx/connectors/egnyte/connector.py +++ b/backend/onyx/connectors/egnyte/connector.py @@ -7,7 +7,8 @@ from typing import Any from typing import IO from urllib.parse import quote -from onyx.configs.app_configs import EGNYTE_BASE_DOMAIN +from pydantic import Field + from onyx.configs.app_configs import EGNYTE_CLIENT_ID from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET from onyx.configs.app_configs import INDEX_BATCH_SIZE @@ -124,6 +125,15 @@ def _process_egnyte_file( class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector): + class AdditionalOauthKwargs(OAuthConnector.AdditionalOauthKwargs): + egnyte_domain: str = Field( + title="Egnyte Domain", + description=( + "The domain for the Egnyte instance " + "(e.g. 'company' for company.egnyte.com)" + ), + ) + def __init__( self, folder_path: str | None = None, @@ -139,15 +149,20 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector): return DocumentSource.EGNYTE @classmethod - def oauth_authorization_url(cls, base_domain: str, state: str) -> str: + def oauth_authorization_url( + cls, + base_domain: str, + state: str, + additional_kwargs: dict[str, str], + ) -> str: if not EGNYTE_CLIENT_ID: raise ValueError("EGNYTE_CLIENT_ID environment variable must be set") - if not EGNYTE_BASE_DOMAIN: - raise ValueError("EGNYTE_DOMAIN environment variable must be set") + + oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs) callback_uri = get_oauth_callback_uri(base_domain, "egnyte") return ( - f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token" + f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token" f"?client_id={EGNYTE_CLIENT_ID}" f"&redirect_uri={callback_uri}" f"&scope=Egnyte.filesystem" @@ -156,17 +171,23 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector): ) @classmethod - def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]: + def oauth_code_to_token( + cls, + base_domain: str, + code: str, + additional_kwargs: dict[str, str], + ) -> dict[str, Any]: if not EGNYTE_CLIENT_ID: raise ValueError("EGNYTE_CLIENT_ID environment variable must be set") if not EGNYTE_CLIENT_SECRET: raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set") - if not EGNYTE_BASE_DOMAIN: - raise ValueError("EGNYTE_DOMAIN environment variable must be set") + + oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs) # Exchange code for token - url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token" + url = f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token" redirect_uri = get_oauth_callback_uri(base_domain, "egnyte") + data = { "client_id": EGNYTE_CLIENT_ID, "client_secret": EGNYTE_CLIENT_SECRET, @@ -191,7 +212,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector): token_data = response.json() return { - "domain": EGNYTE_BASE_DOMAIN, + "domain": oauth_kwargs.egnyte_domain, "access_token": token_data["access_token"], } @@ -215,7 +236,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector): "list_content": True, } - url_encoded_path = quote(path or "", safe="") + url_encoded_path = quote(path or "") url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}" response = request_with_retries( method="GET", url=url, headers=headers, params=params @@ -271,7 +292,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector): headers = { "Authorization": f"Bearer {self.access_token}", } - url_encoded_path = quote(file["path"], safe="") + url_encoded_path = quote(file["path"]) url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}" response = request_with_retries( method="GET", diff --git a/backend/onyx/connectors/interfaces.py b/backend/onyx/connectors/interfaces.py index e2726ee667..10e1ab21f3 100644 --- a/backend/onyx/connectors/interfaces.py +++ b/backend/onyx/connectors/interfaces.py @@ -2,6 +2,8 @@ import abc from collections.abc import Iterator from typing import Any +from pydantic import BaseModel + from onyx.configs.constants import DocumentSource from onyx.connectors.models import Document from onyx.connectors.models import SlimDocument @@ -66,6 +68,10 @@ class SlimConnector(BaseConnector): class OAuthConnector(BaseConnector): + class AdditionalOauthKwargs(BaseModel): + # if overridden, all fields should be str type + pass + @classmethod @abc.abstractmethod def oauth_id(cls) -> DocumentSource: @@ -73,12 +79,22 @@ class OAuthConnector(BaseConnector): @classmethod @abc.abstractmethod - def oauth_authorization_url(cls, base_domain: str, state: str) -> str: + def oauth_authorization_url( + cls, + base_domain: str, + state: str, + additional_kwargs: dict[str, str], + ) -> str: raise NotImplementedError @classmethod @abc.abstractmethod - def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]: + def oauth_code_to_token( + cls, + base_domain: str, + code: str, + additional_kwargs: dict[str, str], + ) -> dict[str, Any]: raise NotImplementedError diff --git a/backend/onyx/connectors/linear/connector.py b/backend/onyx/connectors/linear/connector.py index 0bd10e91f7..07f4bce2ee 100644 --- a/backend/onyx/connectors/linear/connector.py +++ b/backend/onyx/connectors/linear/connector.py @@ -77,7 +77,9 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector): return DocumentSource.LINEAR @classmethod - def oauth_authorization_url(cls, base_domain: str, state: str) -> str: + def oauth_authorization_url( + cls, base_domain: str, state: str, additional_kwargs: dict[str, str] + ) -> str: if not LINEAR_CLIENT_ID: raise ValueError("LINEAR_CLIENT_ID environment variable must be set") @@ -92,7 +94,9 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector): ) @classmethod - def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]: + def oauth_code_to_token( + cls, base_domain: str, code: str, additional_kwargs: dict[str, str] + ) -> dict[str, Any]: data = { "code": code, "redirect_uri": get_oauth_callback_uri( diff --git a/backend/onyx/server/documents/standard_oauth.py b/backend/onyx/server/documents/standard_oauth.py index 961d0f2cb4..91cd5322ec 100644 --- a/backend/onyx/server/documents/standard_oauth.py +++ b/backend/onyx/server/documents/standard_oauth.py @@ -1,3 +1,4 @@ +import json import uuid from typing import Annotated from typing import cast @@ -6,7 +7,9 @@ from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from fastapi import Query +from fastapi import Request from pydantic import BaseModel +from pydantic import ValidationError from sqlalchemy.orm import Session from onyx.auth.users import current_user @@ -28,6 +31,8 @@ router = APIRouter(prefix="/connector/oauth") _OAUTH_STATE_KEY_FMT = "oauth_state:{state}" _OAUTH_STATE_EXPIRATION_SECONDS = 10 * 60 # 10 minutes +_DESIRED_RETURN_URL_KEY = "desired_return_url" +_ADDITIONAL_KWARGS_KEY = "additional_kwargs" # Cache for OAuth connectors, populated at module load time _OAUTH_CONNECTORS: dict[DocumentSource, type[OAuthConnector]] = {} @@ -51,12 +56,36 @@ def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]: _discover_oauth_connectors() +def _get_additional_kwargs( + request: Request, connector_cls: type[OAuthConnector], args_to_ignore: list[str] +) -> dict[str, str]: + # get additional kwargs from request + # e.g. anything except for desired_return_url + additional_kwargs_dict = { + k: v for k, v in request.query_params.items() if k not in args_to_ignore + } + try: + # validate + connector_cls.AdditionalOauthKwargs(**additional_kwargs_dict) + except ValidationError: + raise HTTPException( + status_code=400, + detail=( + f"Invalid additional kwargs. Got {additional_kwargs_dict}, expected " + f"{connector_cls.AdditionalOauthKwargs.model_json_schema()}" + ), + ) + + return additional_kwargs_dict + + class AuthorizeResponse(BaseModel): redirect_url: str @router.get("/authorize/{source}") def oauth_authorize( + request: Request, source: DocumentSource, desired_return_url: Annotated[str | None, Query()] = None, _: User = Depends(current_user), @@ -71,6 +100,12 @@ def oauth_authorize( connector_cls = oauth_connectors[source] base_url = WEB_DOMAIN + # get additional kwargs from request + # e.g. anything except for desired_return_url + additional_kwargs = _get_additional_kwargs( + request, connector_cls, ["desired_return_url"] + ) + # store state in redis if not desired_return_url: desired_return_url = f"{base_url}/admin/connectors/{source}?step=0" @@ -78,12 +113,19 @@ def oauth_authorize( state = str(uuid.uuid4()) redis_client.set( _OAUTH_STATE_KEY_FMT.format(state=state), - desired_return_url, + json.dumps( + { + _DESIRED_RETURN_URL_KEY: desired_return_url, + _ADDITIONAL_KWARGS_KEY: additional_kwargs, + } + ), ex=_OAUTH_STATE_EXPIRATION_SECONDS, ) return AuthorizeResponse( - redirect_url=connector_cls.oauth_authorization_url(base_url, state) + redirect_url=connector_cls.oauth_authorization_url( + base_url, state, additional_kwargs + ) ) @@ -110,15 +152,18 @@ def oauth_callback( # get state from redis redis_client = get_redis_client(tenant_id=tenant_id) - original_url_bytes = cast( + oauth_state_bytes = cast( bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state)) ) - if not original_url_bytes: + if not oauth_state_bytes: raise HTTPException(status_code=400, detail="Invalid OAuth state") - original_url = original_url_bytes.decode("utf-8") + oauth_state = json.loads(oauth_state_bytes.decode("utf-8")) + + desired_return_url = cast(str, oauth_state[_DESIRED_RETURN_URL_KEY]) + additional_kwargs = cast(dict[str, str], oauth_state[_ADDITIONAL_KWARGS_KEY]) base_url = WEB_DOMAIN - token_info = connector_cls.oauth_code_to_token(base_url, code) + token_info = connector_cls.oauth_code_to_token(base_url, code, additional_kwargs) # Create a new credential with the token info credential_data = CredentialBase( @@ -136,8 +181,52 @@ def oauth_callback( return CallbackResponse( redirect_url=( - f"{original_url}?credentialId={credential.id}" - if "?" not in original_url - else f"{original_url}&credentialId={credential.id}" + f"{desired_return_url}?credentialId={credential.id}" + if "?" not in desired_return_url + else f"{desired_return_url}&credentialId={credential.id}" ) ) + + +class OAuthAdditionalKwargDescription(BaseModel): + name: str + display_name: str + description: str + + +class OAuthDetails(BaseModel): + oauth_enabled: bool + additional_kwargs: list[OAuthAdditionalKwargDescription] + + +@router.get("/details/{source}") +def oauth_details( + source: DocumentSource, + _: User = Depends(current_user), +) -> OAuthDetails: + oauth_connectors = _discover_oauth_connectors() + + if source not in oauth_connectors: + return OAuthDetails( + oauth_enabled=False, + additional_kwargs=[], + ) + + connector_cls = oauth_connectors[source] + + additional_kwarg_descriptions = [] + for key, value in connector_cls.AdditionalOauthKwargs.model_json_schema()[ + "properties" + ].items(): + additional_kwarg_descriptions.append( + OAuthAdditionalKwargDescription( + name=key, + display_name=value.get("title", key), + description=value.get("description", ""), + ) + ) + + return OAuthDetails( + oauth_enabled=True, + additional_kwargs=additional_kwarg_descriptions, + ) diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index f5c5463bd4..1d6530161c 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -196,7 +196,6 @@ services: # Egnyte OAuth Configs - EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-} - EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-} - - EGNYTE_BASE_DOMAIN=${EGNYTE_BASE_DOMAIN:-} - EGNYTE_LOCALHOST_OVERRIDE=${EGNYTE_LOCALHOST_OVERRIDE:-} # Celery Configs (defaults are set in the supervisord.conf file. # prefer doing that to have one source of defaults) diff --git a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx index 534fae2814..f07738d3a6 100644 --- a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx +++ b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx @@ -2,7 +2,6 @@ import { errorHandlingFetcher } from "@/lib/fetcher"; import useSWR, { mutate } from "swr"; -import { HealthCheckBanner } from "@/components/health/healthcheck"; import Title from "@/components/ui/title"; import { AdminPageTitle } from "@/components/admin/Title"; @@ -19,12 +18,12 @@ import AdvancedFormPage from "./pages/Advanced"; import DynamicConnectionForm from "./pages/DynamicConnectorCreationForm"; import CreateCredential from "@/components/credentials/actions/CreateCredential"; import ModifyCredential from "@/components/credentials/actions/ModifyCredential"; +import { ConfigurableSources, oauthSupportedSources } from "@/lib/types"; import { - ConfigurableSources, - oauthSupportedSources, - ValidSources, -} from "@/lib/types"; -import { Credential, credentialTemplates } from "@/lib/connectors/credentials"; + Credential, + credentialTemplates, + OAuthDetails, +} from "@/lib/connectors/credentials"; import { ConnectionConfiguration, connectorConfigs, @@ -37,7 +36,6 @@ import { ConnectorBase, } from "@/lib/connectors/connectors"; import { Modal } from "@/components/Modal"; -import GDriveMain from "./pages/gdrive/GoogleDrivePage"; import { GmailMain } from "./pages/gmail/GmailPage"; import { useGmailCredentials, @@ -54,7 +52,12 @@ import { NEXT_PUBLIC_TEST_ENV, } from "@/lib/constants"; import TemporaryLoadingModal from "@/components/TemporaryLoadingModal"; -import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth"; +import { + getConnectorOauthRedirectUrl, + useOAuthDetails, +} from "@/lib/connectors/oauth"; +import { CreateStdOAuthCredential } from "@/components/credentials/actions/CreateStdOAuthCredential"; +import { Spinner } from "@/components/Spinner"; export interface AdvancedConfig { refreshFreq: number; pruneFreq: number; @@ -144,7 +147,8 @@ export default function AddConnector({ // State for managing credentials and files const [currentCredential, setCurrentCredential] = useState | null>(null); - const [createConnectorToggle, setCreateConnectorToggle] = useState(false); + const [createCredentialFormToggle, setCreateCredentialFormToggle] = + useState(false); // Fetch credentials data const { data: credentials } = useSWR[]>( @@ -159,6 +163,9 @@ export default function AddConnector({ { refreshInterval: 5000 } ); + const { data: oauthDetails, isLoading: oauthDetailsLoading } = + useOAuthDetails(connector); + // Get credential template and configuration const credentialTemplate = credentialTemplates[connector]; const configuration: ConnectionConfiguration = connectorConfigs[connector]; @@ -450,19 +457,33 @@ export default function AddConnector({ onDeleteCredential={onDeleteCredential} onSwitch={onSwap} /> - {!createConnectorToggle && ( + {!createCredentialFormToggle && (
{/* Button to pop up a form to manually enter credentials */}
)} - {createConnectorToggle && ( + {createCredentialFormToggle && ( setCreateConnectorToggle(false)} + onOutsideClick={() => + setCreateCredentialFormToggle(false) + } > - <> - - Create a {getSourceDisplayName(connector)}{" "} - credential - - setCreateConnectorToggle(false)} - /> - + {oauthDetailsLoading ? ( + + ) : ( + <> + + Create a {getSourceDisplayName(connector)}{" "} + credential + + {oauthDetails && oauthDetails.oauth_enabled ? ( + + ) : ( + + setCreateCredentialFormToggle(false) + } + /> + )} + + )} )} diff --git a/web/src/components/credentials/CredentialSection.tsx b/web/src/components/credentials/CredentialSection.tsx index b20f2ea818..0b3a6c984f 100644 --- a/web/src/components/credentials/CredentialSection.tsx +++ b/web/src/components/credentials/CredentialSection.tsx @@ -28,7 +28,12 @@ import { ConfluenceCredentialJson, Credential, } from "@/lib/connectors/credentials"; -import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth"; +import { + getConnectorOauthRedirectUrl, + useOAuthDetails, +} from "@/lib/connectors/oauth"; +import { Spinner } from "@/components/Spinner"; +import { CreateStdOAuthCredential } from "@/components/credentials/actions/CreateStdOAuthCredential"; export default function CredentialSection({ ccPair, @@ -39,16 +44,6 @@ export default function CredentialSection({ sourceType: ValidSources; refresh: () => void; }) { - const makeShowCreateCredential = async () => { - const redirectUrl = await getConnectorOauthRedirectUrl(sourceType); - if (redirectUrl) { - window.location.href = redirectUrl; - } else { - setShowModifyCredential(false); - setShowCreateCredential(true); - } - }; - const { data: credentials } = useSWR[]>( buildSimilarCredentialInfoURL(sourceType), errorHandlingFetcher, @@ -59,6 +54,28 @@ export default function CredentialSection({ errorHandlingFetcher, { refreshInterval: 5000 } ); + const { data: oauthDetails, isLoading: oauthDetailsLoading } = + useOAuthDetails(sourceType); + + const makeShowCreateCredential = async () => { + if (oauthDetailsLoading || !oauthDetails) { + return; + } + + if (oauthDetails.oauth_enabled) { + if (oauthDetails.additional_kwargs.length > 0) { + setShowCreateCredential(true); + } else { + const redirectUrl = await getConnectorOauthRedirectUrl(sourceType, {}); + if (redirectUrl) { + window.location.href = redirectUrl; + } + } + } else { + setShowModifyCredential(false); + setShowCreateCredential(true); + } + }; const onSwap = async ( selectedCredential: Credential, @@ -193,13 +210,26 @@ export default function CredentialSection({ className="max-w-3xl rounded-lg" title={`Create ${getSourceDisplayName(sourceType)} Credential`} > - + {oauthDetailsLoading ? ( + + ) : ( + <> + {oauthDetails && oauthDetails.oauth_enabled ? ( + + ) : ( + + )} + + )} )} diff --git a/web/src/components/credentials/actions/CreateStdOAuthCredential.tsx b/web/src/components/credentials/actions/CreateStdOAuthCredential.tsx new file mode 100644 index 0000000000..ff52375d54 --- /dev/null +++ b/web/src/components/credentials/actions/CreateStdOAuthCredential.tsx @@ -0,0 +1,88 @@ +import * as Yup from "yup"; + +import React from "react"; +import { Button } from "@/components/ui/button"; +import { ValidSources } from "@/lib/types"; +import { TextFormField } from "@/components/admin/connectors/Field"; +import { Form, Formik, FormikHelpers } from "formik"; +import CardSection from "@/components/admin/CardSection"; +import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth"; +import { OAuthAdditionalKwargDescription } from "@/lib/connectors/credentials"; + +type formType = { + [key: string]: any; // For additional credential fields +}; + +export function CreateStdOAuthCredential({ + sourceType, + additionalFields, +}: { + // Source information + sourceType: ValidSources; + + additionalFields: OAuthAdditionalKwargDescription[]; +}) { + const handleSubmit = async ( + values: formType, + formikHelpers: FormikHelpers + ) => { + const { setSubmitting, validateForm } = formikHelpers; + + const errors = await validateForm(values); + if (Object.keys(errors).length > 0) { + formikHelpers.setErrors(errors); + return; + } + + setSubmitting(true); + formikHelpers.setSubmitting(true); + + const redirectUrl = await getConnectorOauthRedirectUrl(sourceType, values); + + if (!redirectUrl) { + throw new Error("No redirect URL found for OAuth connector"); + } + + window.location.href = redirectUrl; + }; + + return ( + [field, ""])), + } as formType + } + validationSchema={Yup.object().shape({ + ...Object.fromEntries( + additionalFields.map((field) => [field.name, Yup.string().required()]) + ), + })} + onSubmit={(values, formikHelpers) => { + handleSubmit(values, formikHelpers); + }} + > + {() => ( +
+ + {additionalFields.map((field) => ( + + ))} + +
+ +
+
+
+ )} +
+ ); +} diff --git a/web/src/lib/connectors/credentials.ts b/web/src/lib/connectors/credentials.ts index e030e2578c..b1d1a18d89 100644 --- a/web/src/lib/connectors/credentials.ts +++ b/web/src/lib/connectors/credentials.ts @@ -1,5 +1,16 @@ import { ValidSources } from "../types"; +export interface OAuthAdditionalKwargDescription { + name: string; + display_name: string; + description: string; +} + +export interface OAuthDetails { + oauth_enabled: boolean; + additional_kwargs: OAuthAdditionalKwargDescription[]; +} + export interface CredentialBase { credential_json: T; admin_public: boolean; diff --git a/web/src/lib/connectors/oauth.ts b/web/src/lib/connectors/oauth.ts index dc5adee725..d3472ccba6 100644 --- a/web/src/lib/connectors/oauth.ts +++ b/web/src/lib/connectors/oauth.ts @@ -1,12 +1,18 @@ +import useSWR from "swr"; import { ValidSources } from "../types"; +import { OAuthDetails } from "./credentials"; +import { errorHandlingFetcher } from "../fetcher"; export async function getConnectorOauthRedirectUrl( - connector: ValidSources + connector: ValidSources, + additional_kwargs: Record ): Promise { + const queryParams = new URLSearchParams({ + desired_return_url: window.location.href, + ...additional_kwargs, + }); const response = await fetch( - `/api/connector/oauth/authorize/${connector}?desired_return_url=${encodeURIComponent( - window.location.href - )}` + `/api/connector/oauth/authorize/${connector}?${queryParams.toString()}` ); if (!response.ok) { @@ -17,3 +23,10 @@ export async function getConnectorOauthRedirectUrl( const data = await response.json(); return data.redirect_url as string; } + +export function useOAuthDetails(sourceType: ValidSources) { + return useSWR( + `/api/connector/oauth/details/${sourceType}`, + errorHandlingFetcher + ); +}