mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-19 20:40:57 +02:00
Linear OAuth Connector (#3570)
This commit is contained in:
parent
240f3e4fff
commit
ccd3983802
@ -58,6 +58,9 @@ SESSION_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
|
||||
# Default request timeout, mostly used by connectors
|
||||
REQUEST_TIMEOUT_SECONDS = int(os.environ.get("REQUEST_TIMEOUT_SECONDS") or 60)
|
||||
|
||||
# set `VALID_EMAIL_DOMAINS` to a comma seperated list of domains in order to
|
||||
# restrict access to Onyx to only users with emails from those domains.
|
||||
# E.g. `VALID_EMAIL_DOMAINS=example.com,example.org` will restrict Onyx
|
||||
@ -367,12 +370,18 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
|
||||
os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Typically set to http://localhost:3000 for OAuth connector development
|
||||
CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE")
|
||||
|
||||
# Egnyte specific configs
|
||||
EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE")
|
||||
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
|
||||
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
|
||||
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
|
||||
|
||||
# Linear specific configs
|
||||
LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
|
||||
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
|
||||
|
||||
DASK_JOB_CLIENT_ENABLED = (
|
||||
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
@ -6,6 +6,7 @@ from typing import TypeVar
|
||||
|
||||
from dateutil.parser import parse
|
||||
|
||||
from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
|
||||
from onyx.configs.constants import IGNORE_FOR_QA
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.utils.text_processing import is_valid_email
|
||||
@ -71,3 +72,10 @@ def process_in_batches(
|
||||
|
||||
def get_metadata_keys_to_ignore() -> list[str]:
|
||||
return [IGNORE_FOR_QA]
|
||||
|
||||
|
||||
def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str:
|
||||
if CONNECTOR_LOCALHOST_OVERRIDE:
|
||||
# Used for development
|
||||
base_domain = CONNECTOR_LOCALHOST_OVERRIDE
|
||||
return f"{base_domain.strip('/')}/connector/oauth/callback/{connector_id}"
|
||||
|
@ -3,21 +3,18 @@ import os
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from logging import Logger
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import IO
|
||||
from urllib.parse import quote
|
||||
|
||||
import requests
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import EGNYTE_BASE_DOMAIN
|
||||
from onyx.configs.app_configs import EGNYTE_CLIENT_ID
|
||||
from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_oauth_callback_uri,
|
||||
)
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import OAuthConnector
|
||||
@ -34,54 +31,13 @@ from onyx.file_processing.extract_file_text import is_text_file_extension
|
||||
from onyx.file_processing.extract_file_text import is_valid_file_ext
|
||||
from onyx.file_processing.extract_file_text import read_text_file
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import request_with_retries
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1"
|
||||
_EGNYTE_APP_BASE = "https://{domain}.egnyte.com"
|
||||
_TIMEOUT = 60
|
||||
|
||||
|
||||
def _request_with_retries(
|
||||
method: str,
|
||||
url: str,
|
||||
data: dict[str, Any] | None = None,
|
||||
headers: dict[str, Any] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
timeout: int = _TIMEOUT,
|
||||
stream: bool = False,
|
||||
tries: int = 8,
|
||||
delay: float = 1,
|
||||
backoff: float = 2,
|
||||
) -> requests.Response:
|
||||
@retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
|
||||
def _make_request() -> requests.Response:
|
||||
response = requests.request(
|
||||
method,
|
||||
url,
|
||||
data=data,
|
||||
headers=headers,
|
||||
params=params,
|
||||
timeout=timeout,
|
||||
stream=stream,
|
||||
)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code != 403:
|
||||
logger.exception(
|
||||
f"Failed to call Egnyte API.\n"
|
||||
f"URL: {url}\n"
|
||||
# NOTE: can't log headers because they contain the access token
|
||||
# f"Headers: {headers}\n"
|
||||
f"Data: {data}\n"
|
||||
f"Params: {params}"
|
||||
)
|
||||
raise e
|
||||
return response
|
||||
|
||||
return _make_request()
|
||||
|
||||
|
||||
def _parse_last_modified(last_modified: str) -> datetime:
|
||||
@ -189,10 +145,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
if not EGNYTE_BASE_DOMAIN:
|
||||
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
|
||||
|
||||
if EGNYTE_LOCALHOST_OVERRIDE:
|
||||
base_domain = EGNYTE_LOCALHOST_OVERRIDE
|
||||
|
||||
callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte"
|
||||
callback_uri = get_oauth_callback_uri(base_domain, "egnyte")
|
||||
return (
|
||||
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
|
||||
f"?client_id={EGNYTE_CLIENT_ID}"
|
||||
@ -213,7 +166,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
|
||||
# Exchange code for token
|
||||
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
|
||||
redirect_uri = f"{EGNYTE_LOCALHOST_OVERRIDE or base_domain}/connector/oauth/callback/egnyte"
|
||||
redirect_uri = get_oauth_callback_uri(base_domain, "egnyte")
|
||||
data = {
|
||||
"client_id": EGNYTE_CLIENT_ID,
|
||||
"client_secret": EGNYTE_CLIENT_SECRET,
|
||||
@ -224,7 +177,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
}
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
|
||||
response = _request_with_retries(
|
||||
response = request_with_retries(
|
||||
method="POST",
|
||||
url=url,
|
||||
data=data,
|
||||
@ -264,8 +217,8 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
|
||||
url_encoded_path = quote(path or "", safe="")
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}"
|
||||
response = _request_with_retries(
|
||||
method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT
|
||||
response = request_with_retries(
|
||||
method="GET", url=url, headers=headers, params=params
|
||||
)
|
||||
if not response.ok:
|
||||
raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}")
|
||||
@ -320,11 +273,10 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
}
|
||||
url_encoded_path = quote(file["path"], safe="")
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}"
|
||||
response = _request_with_retries(
|
||||
response = request_with_retries(
|
||||
method="GET",
|
||||
url=url,
|
||||
headers=headers,
|
||||
timeout=_TIMEOUT,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
|
@ -7,16 +7,23 @@ from typing import cast
|
||||
import requests
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import LINEAR_CLIENT_ID
|
||||
from onyx.configs.app_configs import LINEAR_CLIENT_SECRET
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_oauth_callback_uri,
|
||||
)
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import OAuthConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import request_with_retries
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@ -57,7 +64,7 @@ def _make_query(request_body: dict[str, Any], api_key: str) -> requests.Response
|
||||
)
|
||||
|
||||
|
||||
class LinearConnector(LoadConnector, PollConnector):
|
||||
class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
@ -65,8 +72,64 @@ class LinearConnector(LoadConnector, PollConnector):
|
||||
self.batch_size = batch_size
|
||||
self.linear_api_key: str | None = None
|
||||
|
||||
@classmethod
|
||||
def oauth_id(cls) -> DocumentSource:
|
||||
return DocumentSource.LINEAR
|
||||
|
||||
@classmethod
|
||||
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
|
||||
if not LINEAR_CLIENT_ID:
|
||||
raise ValueError("LINEAR_CLIENT_ID environment variable must be set")
|
||||
|
||||
callback_uri = get_oauth_callback_uri(base_domain, DocumentSource.LINEAR.value)
|
||||
return (
|
||||
f"https://linear.app/oauth/authorize"
|
||||
f"?client_id={LINEAR_CLIENT_ID}"
|
||||
f"&redirect_uri={callback_uri}"
|
||||
f"&response_type=code"
|
||||
f"&scope=read"
|
||||
f"&state={state}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
|
||||
data = {
|
||||
"code": code,
|
||||
"redirect_uri": get_oauth_callback_uri(
|
||||
base_domain, DocumentSource.LINEAR.value
|
||||
),
|
||||
"client_id": LINEAR_CLIENT_ID,
|
||||
"client_secret": LINEAR_CLIENT_SECRET,
|
||||
"grant_type": "authorization_code",
|
||||
}
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
|
||||
response = request_with_retries(
|
||||
method="POST",
|
||||
url="https://api.linear.app/oauth/token",
|
||||
data=data,
|
||||
headers=headers,
|
||||
backoff=0,
|
||||
delay=0.1,
|
||||
)
|
||||
if not response.ok:
|
||||
raise RuntimeError(f"Failed to exchange code for token: {response.text}")
|
||||
|
||||
token_data = response.json()
|
||||
|
||||
return {
|
||||
"access_token": token_data["access_token"],
|
||||
}
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.linear_api_key = cast(str, credentials["linear_api_key"])
|
||||
if "linear_api_key" in credentials:
|
||||
self.linear_api_key = cast(str, credentials["linear_api_key"])
|
||||
elif "access_token" in credentials:
|
||||
self.linear_api_key = "Bearer " + cast(str, credentials["access_token"])
|
||||
else:
|
||||
# May need to handle case in the future if the OAuth flow expires
|
||||
raise ConnectorMissingCredentialError("Linear")
|
||||
|
||||
return None
|
||||
|
||||
def _process_issues(
|
||||
|
@ -199,7 +199,7 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
|
||||
# value = redis_client.get('key')
|
||||
# print(value.decode()) # Output: 'value'
|
||||
|
||||
_async_redis_connection = None
|
||||
_async_redis_connection: aioredis.Redis | None = None
|
||||
_async_lock = asyncio.Lock()
|
||||
|
||||
|
||||
|
@ -4,8 +4,10 @@ from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
import requests
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@ -42,3 +44,48 @@ def retry_builder(
|
||||
return cast(F, wrapped_func)
|
||||
|
||||
return retry_with_default
|
||||
|
||||
|
||||
def request_with_retries(
|
||||
method: str,
|
||||
url: str,
|
||||
*,
|
||||
data: dict[str, Any] | None = None,
|
||||
headers: dict[str, Any] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
timeout: int = REQUEST_TIMEOUT_SECONDS,
|
||||
stream: bool = False,
|
||||
tries: int = 8,
|
||||
delay: float = 1,
|
||||
backoff: float = 2,
|
||||
) -> requests.Response:
|
||||
@retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
|
||||
def _make_request() -> requests.Response:
|
||||
response = requests.request(
|
||||
method=method,
|
||||
url=url,
|
||||
data=data,
|
||||
headers=headers,
|
||||
params=params,
|
||||
timeout=timeout,
|
||||
stream=stream,
|
||||
)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.HTTPError:
|
||||
logger.exception(
|
||||
"Request failed:\n%s",
|
||||
{
|
||||
"method": method,
|
||||
"url": url,
|
||||
"data": data,
|
||||
"headers": headers,
|
||||
"params": params,
|
||||
"timeout": timeout,
|
||||
"stream": stream,
|
||||
},
|
||||
)
|
||||
raise
|
||||
return response
|
||||
|
||||
return _make_request()
|
||||
|
Loading…
x
Reference in New Issue
Block a user