mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-21 05:20:55 +02:00
Linear OAuth Connector (#3570)
This commit is contained in:
parent
240f3e4fff
commit
ccd3983802
@ -58,6 +58,9 @@ SESSION_EXPIRE_TIME_SECONDS = int(
|
|||||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
|
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||||
) # 7 days
|
) # 7 days
|
||||||
|
|
||||||
|
# Default request timeout, mostly used by connectors
|
||||||
|
REQUEST_TIMEOUT_SECONDS = int(os.environ.get("REQUEST_TIMEOUT_SECONDS") or 60)
|
||||||
|
|
||||||
# set `VALID_EMAIL_DOMAINS` to a comma-separated list of domains in order to
|
# set `VALID_EMAIL_DOMAINS` to a comma-separated list of domains in order to
|
||||||
# restrict access to Onyx to only users with emails from those domains.
|
# restrict access to Onyx to only users with emails from those domains.
|
||||||
# E.g. `VALID_EMAIL_DOMAINS=example.com,example.org` will restrict Onyx
|
# E.g. `VALID_EMAIL_DOMAINS=example.com,example.org` will restrict Onyx
|
||||||
@ -367,12 +370,18 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
|
|||||||
os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
|
os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Typically set to http://localhost:3000 for OAuth connector development
|
||||||
|
CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE")
|
||||||
|
|
||||||
# Egnyte specific configs
|
# Egnyte specific configs
|
||||||
EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE")
|
|
||||||
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
|
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
|
||||||
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
|
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
|
||||||
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
|
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
|
||||||
|
|
||||||
|
# Linear specific configs
|
||||||
|
LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
|
||||||
|
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
|
||||||
|
|
||||||
DASK_JOB_CLIENT_ENABLED = (
|
DASK_JOB_CLIENT_ENABLED = (
|
||||||
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
|
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
|
||||||
)
|
)
|
||||||
|
@ -6,6 +6,7 @@ from typing import TypeVar
|
|||||||
|
|
||||||
from dateutil.parser import parse
|
from dateutil.parser import parse
|
||||||
|
|
||||||
|
from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
|
||||||
from onyx.configs.constants import IGNORE_FOR_QA
|
from onyx.configs.constants import IGNORE_FOR_QA
|
||||||
from onyx.connectors.models import BasicExpertInfo
|
from onyx.connectors.models import BasicExpertInfo
|
||||||
from onyx.utils.text_processing import is_valid_email
|
from onyx.utils.text_processing import is_valid_email
|
||||||
@ -71,3 +72,10 @@ def process_in_batches(
|
|||||||
|
|
||||||
def get_metadata_keys_to_ignore() -> list[str]:
    """Return a fresh single-element list holding the ``IGNORE_FOR_QA`` key."""
    ignored_keys: list[str] = [IGNORE_FOR_QA]
    return ignored_keys
|
||||||
|
|
||||||
|
|
||||||
|
def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str:
    """Build the OAuth callback URI for a connector.

    Args:
        base_domain: Base URL to build the callback on. Ignored when
            ``CONNECTOR_LOCALHOST_OVERRIDE`` is set to a non-empty value.
        connector_id: Identifier appended as the last path segment.

    Returns:
        ``<domain>/connector/oauth/callback/<connector_id>``, where leading
        and trailing ``/`` characters have been stripped from the domain.
    """
    # A non-empty override (used for development) wins over the caller's domain.
    effective_domain = CONNECTOR_LOCALHOST_OVERRIDE or base_domain
    trimmed_domain = effective_domain.strip("/")
    return f"{trimmed_domain}/connector/oauth/callback/{connector_id}"
|
||||||
|
@ -3,21 +3,18 @@ import os
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from datetime import timezone
|
from datetime import timezone
|
||||||
from logging import Logger
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import cast
|
|
||||||
from typing import IO
|
from typing import IO
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
import requests
|
|
||||||
from retry import retry
|
|
||||||
|
|
||||||
from onyx.configs.app_configs import EGNYTE_BASE_DOMAIN
|
from onyx.configs.app_configs import EGNYTE_BASE_DOMAIN
|
||||||
from onyx.configs.app_configs import EGNYTE_CLIENT_ID
|
from onyx.configs.app_configs import EGNYTE_CLIENT_ID
|
||||||
from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET
|
from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET
|
||||||
from onyx.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE
|
|
||||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||||
from onyx.configs.constants import DocumentSource
|
from onyx.configs.constants import DocumentSource
|
||||||
|
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||||
|
get_oauth_callback_uri,
|
||||||
|
)
|
||||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||||
from onyx.connectors.interfaces import LoadConnector
|
from onyx.connectors.interfaces import LoadConnector
|
||||||
from onyx.connectors.interfaces import OAuthConnector
|
from onyx.connectors.interfaces import OAuthConnector
|
||||||
@ -34,54 +31,13 @@ from onyx.file_processing.extract_file_text import is_text_file_extension
|
|||||||
from onyx.file_processing.extract_file_text import is_valid_file_ext
|
from onyx.file_processing.extract_file_text import is_valid_file_ext
|
||||||
from onyx.file_processing.extract_file_text import read_text_file
|
from onyx.file_processing.extract_file_text import read_text_file
|
||||||
from onyx.utils.logger import setup_logger
|
from onyx.utils.logger import setup_logger
|
||||||
|
from onyx.utils.retry_wrapper import request_with_retries
|
||||||
|
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
|
|
||||||
_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1"
|
_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1"
|
||||||
_EGNYTE_APP_BASE = "https://{domain}.egnyte.com"
|
_EGNYTE_APP_BASE = "https://{domain}.egnyte.com"
|
||||||
_TIMEOUT = 60
|
|
||||||
|
|
||||||
|
|
||||||
def _request_with_retries(
|
|
||||||
method: str,
|
|
||||||
url: str,
|
|
||||||
data: dict[str, Any] | None = None,
|
|
||||||
headers: dict[str, Any] | None = None,
|
|
||||||
params: dict[str, Any] | None = None,
|
|
||||||
timeout: int = _TIMEOUT,
|
|
||||||
stream: bool = False,
|
|
||||||
tries: int = 8,
|
|
||||||
delay: float = 1,
|
|
||||||
backoff: float = 2,
|
|
||||||
) -> requests.Response:
|
|
||||||
@retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
|
|
||||||
def _make_request() -> requests.Response:
|
|
||||||
response = requests.request(
|
|
||||||
method,
|
|
||||||
url,
|
|
||||||
data=data,
|
|
||||||
headers=headers,
|
|
||||||
params=params,
|
|
||||||
timeout=timeout,
|
|
||||||
stream=stream,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
response.raise_for_status()
|
|
||||||
except requests.exceptions.HTTPError as e:
|
|
||||||
if e.response.status_code != 403:
|
|
||||||
logger.exception(
|
|
||||||
f"Failed to call Egnyte API.\n"
|
|
||||||
f"URL: {url}\n"
|
|
||||||
# NOTE: can't log headers because they contain the access token
|
|
||||||
# f"Headers: {headers}\n"
|
|
||||||
f"Data: {data}\n"
|
|
||||||
f"Params: {params}"
|
|
||||||
)
|
|
||||||
raise e
|
|
||||||
return response
|
|
||||||
|
|
||||||
return _make_request()
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_last_modified(last_modified: str) -> datetime:
|
def _parse_last_modified(last_modified: str) -> datetime:
|
||||||
@ -189,10 +145,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
|||||||
if not EGNYTE_BASE_DOMAIN:
|
if not EGNYTE_BASE_DOMAIN:
|
||||||
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
|
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
|
||||||
|
|
||||||
if EGNYTE_LOCALHOST_OVERRIDE:
|
callback_uri = get_oauth_callback_uri(base_domain, "egnyte")
|
||||||
base_domain = EGNYTE_LOCALHOST_OVERRIDE
|
|
||||||
|
|
||||||
callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte"
|
|
||||||
return (
|
return (
|
||||||
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
|
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
|
||||||
f"?client_id={EGNYTE_CLIENT_ID}"
|
f"?client_id={EGNYTE_CLIENT_ID}"
|
||||||
@ -213,7 +166,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
|||||||
|
|
||||||
# Exchange code for token
|
# Exchange code for token
|
||||||
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
|
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
|
||||||
redirect_uri = f"{EGNYTE_LOCALHOST_OVERRIDE or base_domain}/connector/oauth/callback/egnyte"
|
redirect_uri = get_oauth_callback_uri(base_domain, "egnyte")
|
||||||
data = {
|
data = {
|
||||||
"client_id": EGNYTE_CLIENT_ID,
|
"client_id": EGNYTE_CLIENT_ID,
|
||||||
"client_secret": EGNYTE_CLIENT_SECRET,
|
"client_secret": EGNYTE_CLIENT_SECRET,
|
||||||
@ -224,7 +177,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
|||||||
}
|
}
|
||||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||||
|
|
||||||
response = _request_with_retries(
|
response = request_with_retries(
|
||||||
method="POST",
|
method="POST",
|
||||||
url=url,
|
url=url,
|
||||||
data=data,
|
data=data,
|
||||||
@ -264,8 +217,8 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
|||||||
|
|
||||||
url_encoded_path = quote(path or "", safe="")
|
url_encoded_path = quote(path or "", safe="")
|
||||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}"
|
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}"
|
||||||
response = _request_with_retries(
|
response = request_with_retries(
|
||||||
method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT
|
method="GET", url=url, headers=headers, params=params
|
||||||
)
|
)
|
||||||
if not response.ok:
|
if not response.ok:
|
||||||
raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}")
|
raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}")
|
||||||
@ -320,11 +273,10 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
|||||||
}
|
}
|
||||||
url_encoded_path = quote(file["path"], safe="")
|
url_encoded_path = quote(file["path"], safe="")
|
||||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}"
|
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}"
|
||||||
response = _request_with_retries(
|
response = request_with_retries(
|
||||||
method="GET",
|
method="GET",
|
||||||
url=url,
|
url=url,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=_TIMEOUT,
|
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -7,16 +7,23 @@ from typing import cast
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||||
|
from onyx.configs.app_configs import LINEAR_CLIENT_ID
|
||||||
|
from onyx.configs.app_configs import LINEAR_CLIENT_SECRET
|
||||||
from onyx.configs.constants import DocumentSource
|
from onyx.configs.constants import DocumentSource
|
||||||
|
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||||
|
get_oauth_callback_uri,
|
||||||
|
)
|
||||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||||
from onyx.connectors.interfaces import LoadConnector
|
from onyx.connectors.interfaces import LoadConnector
|
||||||
|
from onyx.connectors.interfaces import OAuthConnector
|
||||||
from onyx.connectors.interfaces import PollConnector
|
from onyx.connectors.interfaces import PollConnector
|
||||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||||
from onyx.connectors.models import Document
|
from onyx.connectors.models import Document
|
||||||
from onyx.connectors.models import Section
|
from onyx.connectors.models import Section
|
||||||
from onyx.utils.logger import setup_logger
|
from onyx.utils.logger import setup_logger
|
||||||
|
from onyx.utils.retry_wrapper import request_with_retries
|
||||||
|
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
@ -57,7 +64,7 @@ def _make_query(request_body: dict[str, Any], api_key: str) -> requests.Response
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LinearConnector(LoadConnector, PollConnector):
|
class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
batch_size: int = INDEX_BATCH_SIZE,
|
batch_size: int = INDEX_BATCH_SIZE,
|
||||||
@ -65,8 +72,64 @@ class LinearConnector(LoadConnector, PollConnector):
|
|||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.linear_api_key: str | None = None
|
self.linear_api_key: str | None = None
|
||||||
|
|
||||||
|
@classmethod
def oauth_id(cls) -> DocumentSource:
    """Return the ``DocumentSource`` member identifying this OAuth connector."""
    source = DocumentSource.LINEAR
    return source
|
||||||
|
|
||||||
|
@classmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
    """Build the Linear authorization URL that starts the OAuth flow.

    Args:
        base_domain: Base URL handed to ``get_oauth_callback_uri`` to derive
            the ``redirect_uri`` query parameter.
        state: Value placed in the ``state`` query parameter.

    Returns:
        The ``https://linear.app/oauth/authorize`` URL with a percent-encoded
        query string requesting the ``read`` scope.

    Raises:
        ValueError: If ``LINEAR_CLIENT_ID`` is not configured.
    """
    # Function-scope import so the method does not depend on the module's
    # import block already providing urlencode.
    from urllib.parse import urlencode

    if not LINEAR_CLIENT_ID:
        raise ValueError("LINEAR_CLIENT_ID environment variable must be set")

    callback_uri = get_oauth_callback_uri(base_domain, DocumentSource.LINEAR.value)

    # urlencode percent-encodes every value, so a callback URI or state
    # containing '&', '#', '=' or spaces cannot corrupt the query string.
    query = urlencode(
        {
            "client_id": LINEAR_CLIENT_ID,
            "redirect_uri": callback_uri,
            "response_type": "code",
            "scope": "read",
            "state": state,
        }
    )
    return f"https://linear.app/oauth/authorize?{query}"
|
||||||
|
|
||||||
|
@classmethod
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
    """Exchange an OAuth authorization code for a Linear access token.

    Args:
        base_domain: Base URL handed to ``get_oauth_callback_uri``; the
            resulting redirect URI is sent along with the code.
        code: Authorization code received on the OAuth callback.

    Returns:
        A dict with a single ``"access_token"`` entry.

    Raises:
        ValueError: If ``LINEAR_CLIENT_ID`` or ``LINEAR_CLIENT_SECRET`` is
            not configured.
        RuntimeError: If the response is not OK or carries no access token.
    """
    # Fail fast on misconfiguration instead of sending `None` credentials
    # to Linear and retrying a request that can never succeed.
    if not LINEAR_CLIENT_ID:
        raise ValueError("LINEAR_CLIENT_ID environment variable must be set")
    if not LINEAR_CLIENT_SECRET:
        raise ValueError("LINEAR_CLIENT_SECRET environment variable must be set")

    data = {
        "code": code,
        "redirect_uri": get_oauth_callback_uri(
            base_domain, DocumentSource.LINEAR.value
        ),
        "client_id": LINEAR_CLIENT_ID,
        "client_secret": LINEAR_CLIENT_SECRET,
        "grant_type": "authorization_code",
    }
    headers = {"Content-Type": "application/x-www-form-urlencoded"}

    response = request_with_retries(
        method="POST",
        url="https://api.linear.app/oauth/token",
        data=data,
        headers=headers,
        # `backoff` is the multiplier applied to the delay after each
        # attempt; 1 keeps a constant short pause, whereas 0 would collapse
        # the delay to zero and fire the remaining retries back-to-back.
        backoff=1,
        delay=0.1,
    )
    if not response.ok:
        raise RuntimeError(f"Failed to exchange code for token: {response.text}")

    token_data = response.json()
    try:
        access_token = token_data["access_token"]
    except (KeyError, TypeError) as err:
        # The payload is deliberately not echoed: it may hold sensitive data.
        raise RuntimeError(
            "Failed to exchange code for token: response has no access_token"
        ) from err

    return {
        "access_token": access_token,
    }
|
||||||
|
|
||||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
if "linear_api_key" in credentials:
|
||||||
self.linear_api_key = cast(str, credentials["linear_api_key"])
|
self.linear_api_key = cast(str, credentials["linear_api_key"])
|
||||||
|
elif "access_token" in credentials:
|
||||||
|
self.linear_api_key = "Bearer " + cast(str, credentials["access_token"])
|
||||||
|
else:
|
||||||
|
# May need to handle case in the future if the OAuth flow expires
|
||||||
|
raise ConnectorMissingCredentialError("Linear")
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _process_issues(
|
def _process_issues(
|
||||||
|
@ -199,7 +199,7 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
|
|||||||
# value = redis_client.get('key')
|
# value = redis_client.get('key')
|
||||||
# print(value.decode()) # Output: 'value'
|
# print(value.decode()) # Output: 'value'
|
||||||
|
|
||||||
_async_redis_connection = None
|
_async_redis_connection: aioredis.Redis | None = None
|
||||||
_async_lock = asyncio.Lock()
|
_async_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,8 +4,10 @@ from typing import Any
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
|
import requests
|
||||||
from retry import retry
|
from retry import retry
|
||||||
|
|
||||||
|
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
|
||||||
from onyx.utils.logger import setup_logger
|
from onyx.utils.logger import setup_logger
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
@ -42,3 +44,48 @@ def retry_builder(
|
|||||||
return cast(F, wrapped_func)
|
return cast(F, wrapped_func)
|
||||||
|
|
||||||
return retry_with_default
|
return retry_with_default
|
||||||
|
|
||||||
|
|
||||||
|
def request_with_retries(
    method: str,
    url: str,
    *,
    data: dict[str, Any] | None = None,
    headers: dict[str, Any] | None = None,
    params: dict[str, Any] | None = None,
    timeout: int = REQUEST_TIMEOUT_SECONDS,
    stream: bool = False,
    tries: int = 8,
    delay: float = 1,
    backoff: float = 2,
) -> requests.Response:
    """Issue an HTTP request, retrying failed attempts.

    Each attempt calls ``requests.request`` and then ``raise_for_status``,
    so an HTTP error status counts as a failed attempt. Retrying is
    delegated to the ``retry`` decorator.

    Args:
        method: HTTP method name.
        url: Target URL.
        data: Optional request body fields.
        headers: Optional request headers.
        params: Optional query-string parameters.
        timeout: Per-attempt timeout passed to ``requests.request``.
        stream: Passed through to ``requests.request``.
        tries: Maximum number of attempts.
        delay: Initial delay between attempts, in seconds.
        backoff: Multiplier applied to the delay after each attempt.

    Returns:
        The response of the first attempt that passes ``raise_for_status``.

    Raises:
        requests.exceptions.HTTPError: If the last attempt still returns an
            HTTP error status.
    """

    @retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
    def _make_request() -> requests.Response:
        response = requests.request(
            method=method,
            url=url,
            data=data,
            headers=headers,
            params=params,
            timeout=timeout,
            stream=stream,
        )
        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError:
            # NOTE: header values and body values are never logged. Headers
            # carry access tokens and form bodies carry OAuth client secrets
            # and authorization codes, so only their field names are recorded.
            logger.exception(
                "Request failed:\n%s",
                {
                    "method": method,
                    "url": url,
                    "data_keys": list(data) if data else data,
                    "header_names": list(headers) if headers else headers,
                    "params": params,
                    "timeout": timeout,
                    "stream": stream,
                },
            )
            raise
        return response

    return _make_request()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user