Add support for OAuth connectors that require user input (#3571)

* Add support for OAuth connectors that require user input

* Cleanup

* Fix linear

* Small re-naming

* Remove console.log
This commit is contained in:
Chris Weaver
2024-12-31 18:03:33 -08:00
committed by GitHub
parent ccd3983802
commit d64464ca7c
11 changed files with 389 additions and 81 deletions

View File

@@ -7,7 +7,8 @@ from typing import Any
from typing import IO
from urllib.parse import quote
from onyx.configs.app_configs import EGNYTE_BASE_DOMAIN
from pydantic import Field
from onyx.configs.app_configs import EGNYTE_CLIENT_ID
from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET
from onyx.configs.app_configs import INDEX_BATCH_SIZE
@@ -124,6 +125,15 @@ def _process_egnyte_file(
class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
class AdditionalOauthKwargs(OAuthConnector.AdditionalOauthKwargs):
egnyte_domain: str = Field(
title="Egnyte Domain",
description=(
"The domain for the Egnyte instance "
"(e.g. 'company' for company.egnyte.com)"
),
)
def __init__(
self,
folder_path: str | None = None,
@@ -139,15 +149,20 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
return DocumentSource.EGNYTE
@classmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
def oauth_authorization_url(
cls,
base_domain: str,
state: str,
additional_kwargs: dict[str, str],
) -> str:
if not EGNYTE_CLIENT_ID:
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
if not EGNYTE_BASE_DOMAIN:
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
callback_uri = get_oauth_callback_uri(base_domain, "egnyte")
return (
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
f"?client_id={EGNYTE_CLIENT_ID}"
f"&redirect_uri={callback_uri}"
f"&scope=Egnyte.filesystem"
@@ -156,17 +171,23 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
)
@classmethod
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
def oauth_code_to_token(
cls,
base_domain: str,
code: str,
additional_kwargs: dict[str, str],
) -> dict[str, Any]:
if not EGNYTE_CLIENT_ID:
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
if not EGNYTE_CLIENT_SECRET:
raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
if not EGNYTE_BASE_DOMAIN:
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
# Exchange code for token
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
url = f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
redirect_uri = get_oauth_callback_uri(base_domain, "egnyte")
data = {
"client_id": EGNYTE_CLIENT_ID,
"client_secret": EGNYTE_CLIENT_SECRET,
@@ -191,7 +212,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
token_data = response.json()
return {
"domain": EGNYTE_BASE_DOMAIN,
"domain": oauth_kwargs.egnyte_domain,
"access_token": token_data["access_token"],
}
@@ -215,7 +236,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
"list_content": True,
}
url_encoded_path = quote(path or "", safe="")
url_encoded_path = quote(path or "")
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}"
response = request_with_retries(
method="GET", url=url, headers=headers, params=params
@@ -271,7 +292,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
headers = {
"Authorization": f"Bearer {self.access_token}",
}
url_encoded_path = quote(file["path"], safe="")
url_encoded_path = quote(file["path"])
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}"
response = request_with_retries(
method="GET",

View File

@@ -2,6 +2,8 @@ import abc
from collections.abc import Iterator
from typing import Any
from pydantic import BaseModel
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
@@ -66,6 +68,10 @@ class SlimConnector(BaseConnector):
class OAuthConnector(BaseConnector):
class AdditionalOauthKwargs(BaseModel):
# if overridden, all fields should be str type
pass
@classmethod
@abc.abstractmethod
def oauth_id(cls) -> DocumentSource:
@@ -73,12 +79,22 @@ class OAuthConnector(BaseConnector):
@classmethod
@abc.abstractmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
def oauth_authorization_url(
cls,
base_domain: str,
state: str,
additional_kwargs: dict[str, str],
) -> str:
raise NotImplementedError
@classmethod
@abc.abstractmethod
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
def oauth_code_to_token(
cls,
base_domain: str,
code: str,
additional_kwargs: dict[str, str],
) -> dict[str, Any]:
raise NotImplementedError

View File

@@ -77,7 +77,9 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
return DocumentSource.LINEAR
@classmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
def oauth_authorization_url(
cls, base_domain: str, state: str, additional_kwargs: dict[str, str]
) -> str:
if not LINEAR_CLIENT_ID:
raise ValueError("LINEAR_CLIENT_ID environment variable must be set")
@@ -92,7 +94,9 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
)
@classmethod
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
def oauth_code_to_token(
cls, base_domain: str, code: str, additional_kwargs: dict[str, str]
) -> dict[str, Any]:
data = {
"code": code,
"redirect_uri": get_oauth_callback_uri(