Linear OAuth Connector (#3570)

This commit is contained in:
Yuhong Sun 2024-12-31 16:11:09 -08:00 committed by GitHub
parent 240f3e4fff
commit ccd3983802
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 141 additions and 62 deletions

View File

@ -58,6 +58,9 @@ SESSION_EXPIRE_TIME_SECONDS = int(
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
# Default request timeout, mostly used by connectors
REQUEST_TIMEOUT_SECONDS = int(os.environ.get("REQUEST_TIMEOUT_SECONDS") or 60)
# set `VALID_EMAIL_DOMAINS` to a comma seperated list of domains in order to
# restrict access to Onyx to only users with emails from those domains.
# E.g. `VALID_EMAIL_DOMAINS=example.com,example.org` will restrict Onyx
@ -367,12 +370,18 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
)
# Typically set to http://localhost:3000 for OAuth connector development
CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE")
# Egnyte specific configs
EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE")
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
# Linear specific configs
LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
)

View File

@ -6,6 +6,7 @@ from typing import TypeVar
from dateutil.parser import parse
from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
from onyx.configs.constants import IGNORE_FOR_QA
from onyx.connectors.models import BasicExpertInfo
from onyx.utils.text_processing import is_valid_email
@ -71,3 +72,10 @@ def process_in_batches(
def get_metadata_keys_to_ignore() -> list[str]:
return [IGNORE_FOR_QA]
def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str:
if CONNECTOR_LOCALHOST_OVERRIDE:
# Used for development
base_domain = CONNECTOR_LOCALHOST_OVERRIDE
return f"{base_domain.strip('/')}/connector/oauth/callback/{connector_id}"

View File

@ -3,21 +3,18 @@ import os
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from logging import Logger
from typing import Any
from typing import cast
from typing import IO
from urllib.parse import quote
import requests
from retry import retry
from onyx.configs.app_configs import EGNYTE_BASE_DOMAIN
from onyx.configs.app_configs import EGNYTE_CLIENT_ID
from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET
from onyx.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_oauth_callback_uri,
)
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import OAuthConnector
@ -34,54 +31,13 @@ from onyx.file_processing.extract_file_text import is_text_file_extension
from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import read_text_file
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries
logger = setup_logger()
_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1"
_EGNYTE_APP_BASE = "https://{domain}.egnyte.com"
_TIMEOUT = 60
def _request_with_retries(
method: str,
url: str,
data: dict[str, Any] | None = None,
headers: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
timeout: int = _TIMEOUT,
stream: bool = False,
tries: int = 8,
delay: float = 1,
backoff: float = 2,
) -> requests.Response:
@retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
def _make_request() -> requests.Response:
response = requests.request(
method,
url,
data=data,
headers=headers,
params=params,
timeout=timeout,
stream=stream,
)
try:
response.raise_for_status()
except requests.exceptions.HTTPError as e:
if e.response.status_code != 403:
logger.exception(
f"Failed to call Egnyte API.\n"
f"URL: {url}\n"
# NOTE: can't log headers because they contain the access token
# f"Headers: {headers}\n"
f"Data: {data}\n"
f"Params: {params}"
)
raise e
return response
return _make_request()
def _parse_last_modified(last_modified: str) -> datetime:
@ -189,10 +145,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
if not EGNYTE_BASE_DOMAIN:
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
if EGNYTE_LOCALHOST_OVERRIDE:
base_domain = EGNYTE_LOCALHOST_OVERRIDE
callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte"
callback_uri = get_oauth_callback_uri(base_domain, "egnyte")
return (
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
f"?client_id={EGNYTE_CLIENT_ID}"
@ -213,7 +166,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
# Exchange code for token
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
redirect_uri = f"{EGNYTE_LOCALHOST_OVERRIDE or base_domain}/connector/oauth/callback/egnyte"
redirect_uri = get_oauth_callback_uri(base_domain, "egnyte")
data = {
"client_id": EGNYTE_CLIENT_ID,
"client_secret": EGNYTE_CLIENT_SECRET,
@ -224,7 +177,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
}
headers = {"Content-Type": "application/x-www-form-urlencoded"}
response = _request_with_retries(
response = request_with_retries(
method="POST",
url=url,
data=data,
@ -264,8 +217,8 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
url_encoded_path = quote(path or "", safe="")
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}"
response = _request_with_retries(
method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT
response = request_with_retries(
method="GET", url=url, headers=headers, params=params
)
if not response.ok:
raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}")
@ -320,11 +273,10 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
}
url_encoded_path = quote(file["path"], safe="")
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}"
response = _request_with_retries(
response = request_with_retries(
method="GET",
url=url,
headers=headers,
timeout=_TIMEOUT,
stream=True,
)

View File

@ -7,16 +7,23 @@ from typing import cast
import requests
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import LINEAR_CLIENT_ID
from onyx.configs.app_configs import LINEAR_CLIENT_SECRET
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_oauth_callback_uri,
)
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import OAuthConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries
logger = setup_logger()
@ -57,7 +64,7 @@ def _make_query(request_body: dict[str, Any], api_key: str) -> requests.Response
)
class LinearConnector(LoadConnector, PollConnector):
class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
@ -65,8 +72,64 @@ class LinearConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
self.linear_api_key: str | None = None
@classmethod
def oauth_id(cls) -> DocumentSource:
return DocumentSource.LINEAR
@classmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
if not LINEAR_CLIENT_ID:
raise ValueError("LINEAR_CLIENT_ID environment variable must be set")
callback_uri = get_oauth_callback_uri(base_domain, DocumentSource.LINEAR.value)
return (
f"https://linear.app/oauth/authorize"
f"?client_id={LINEAR_CLIENT_ID}"
f"&redirect_uri={callback_uri}"
f"&response_type=code"
f"&scope=read"
f"&state={state}"
)
@classmethod
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
data = {
"code": code,
"redirect_uri": get_oauth_callback_uri(
base_domain, DocumentSource.LINEAR.value
),
"client_id": LINEAR_CLIENT_ID,
"client_secret": LINEAR_CLIENT_SECRET,
"grant_type": "authorization_code",
}
headers = {"Content-Type": "application/x-www-form-urlencoded"}
response = request_with_retries(
method="POST",
url="https://api.linear.app/oauth/token",
data=data,
headers=headers,
backoff=0,
delay=0.1,
)
if not response.ok:
raise RuntimeError(f"Failed to exchange code for token: {response.text}")
token_data = response.json()
return {
"access_token": token_data["access_token"],
}
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.linear_api_key = cast(str, credentials["linear_api_key"])
if "linear_api_key" in credentials:
self.linear_api_key = cast(str, credentials["linear_api_key"])
elif "access_token" in credentials:
self.linear_api_key = "Bearer " + cast(str, credentials["access_token"])
else:
# May need to handle case in the future if the OAuth flow expires
raise ConnectorMissingCredentialError("Linear")
return None
def _process_issues(

View File

@ -199,7 +199,7 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
# value = redis_client.get('key')
# print(value.decode()) # Output: 'value'
_async_redis_connection = None
_async_redis_connection: aioredis.Redis | None = None
_async_lock = asyncio.Lock()

View File

@ -4,8 +4,10 @@ from typing import Any
from typing import cast
from typing import TypeVar
import requests
from retry import retry
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
from onyx.utils.logger import setup_logger
logger = setup_logger()
@ -42,3 +44,48 @@ def retry_builder(
return cast(F, wrapped_func)
return retry_with_default
def request_with_retries(
method: str,
url: str,
*,
data: dict[str, Any] | None = None,
headers: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
timeout: int = REQUEST_TIMEOUT_SECONDS,
stream: bool = False,
tries: int = 8,
delay: float = 1,
backoff: float = 2,
) -> requests.Response:
@retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
def _make_request() -> requests.Response:
response = requests.request(
method=method,
url=url,
data=data,
headers=headers,
params=params,
timeout=timeout,
stream=stream,
)
try:
response.raise_for_status()
except requests.exceptions.HTTPError:
logger.exception(
"Request failed:\n%s",
{
"method": method,
"url": url,
"data": data,
"headers": headers,
"params": params,
"timeout": timeout,
"stream": stream,
},
)
raise
return response
return _make_request()