Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-10 21:09:51 +02:00)
Egnyte connector (#3420)
This commit is contained in:
parent fe83f676df
commit 4e4214b82c
@@ -348,6 +348,12 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
    os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
)

# Egnyte specific configs
EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE")
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")

DASK_JOB_CLIENT_ENABLED = (
    os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
)
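The four Egnyte settings above are read from the environment at import time. A minimal sketch of how a deployment might set them before the backend starts; the values below are placeholders, not part of this commit:

import os

# Placeholder values only. "acme" stands in for the Egnyte tenant, i.e. the
# connector would talk to https://acme.egnyte.com.
os.environ.setdefault("EGNYTE_DOMAIN", "acme")
os.environ.setdefault("EGNYTE_CLIENT_ID", "your-egnyte-client-id")
os.environ.setdefault("EGNYTE_CLIENT_SECRET", "your-egnyte-client-secret")
# Optional: used by the OAuth flow to point the callback at a local dev server.
os.environ.setdefault("EGNYTE_LOCALHOST_OVERRIDE", "http://localhost:3000")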
@@ -132,6 +132,7 @@ class DocumentSource(str, Enum):
    NOT_APPLICABLE = "not_applicable"
    FRESHDESK = "freshdesk"
    FIREFLIES = "fireflies"
    EGNYTE = "egnyte"


DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
backend/danswer/connectors/egnyte/connector.py (new file, 373 lines)
@@ -0,0 +1,373 @@
import io
import os
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from logging import Logger
from typing import Any
from typing import cast
from typing import IO

import requests
from retry import retry

from danswer.configs.app_configs import EGNYTE_BASE_DOMAIN
from danswer.configs.app_configs import EGNYTE_CLIENT_ID
from danswer.configs.app_configs import EGNYTE_CLIENT_SECRET
from danswer.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import OAuthConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import detect_encoding
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.file_processing.extract_file_text import get_file_ext
from danswer.file_processing.extract_file_text import is_text_file_extension
from danswer.file_processing.extract_file_text import is_valid_file_ext
from danswer.file_processing.extract_file_text import read_text_file
from danswer.utils.logger import setup_logger


logger = setup_logger()

_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1"
_EGNYTE_APP_BASE = "https://{domain}.egnyte.com"
_TIMEOUT = 60


def _request_with_retries(
    method: str,
    url: str,
    data: dict[str, Any] | None = None,
    headers: dict[str, Any] | None = None,
    params: dict[str, Any] | None = None,
    timeout: int = _TIMEOUT,
    stream: bool = False,
    tries: int = 8,
    delay: float = 1,
    backoff: float = 2,
) -> requests.Response:
    @retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
    def _make_request() -> requests.Response:
        response = requests.request(
            method,
            url,
            data=data,
            headers=headers,
            params=params,
            timeout=timeout,
            stream=stream,
        )
        response.raise_for_status()
        return response

    return _make_request()


def _parse_last_modified(last_modified: str) -> datetime:
    return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace(
        tzinfo=timezone.utc
    )


def _process_egnyte_file(
    file_metadata: dict[str, Any],
    file_content: IO,
    base_url: str,
    folder_path: str | None = None,
) -> Document | None:
    """Process an Egnyte file into a Document object

    Args:
        file_data: The file data from Egnyte API
        file_content: The raw content of the file in bytes
        base_url: The base URL for the Egnyte instance
        folder_path: Optional folder path to filter results
    """
    # Skip if file path doesn't match folder path filter
    if folder_path and not file_metadata["path"].startswith(folder_path):
        raise ValueError(
            f"File path {file_metadata['path']} does not match folder path {folder_path}"
        )

    file_name = file_metadata["name"]
    extension = get_file_ext(file_name)
    if not is_valid_file_ext(extension):
        logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
        return None

    # Extract text content based on file type
    if is_text_file_extension(file_name):
        encoding = detect_encoding(file_content)
        file_content_raw, file_metadata = read_text_file(
            file_content, encoding=encoding, ignore_danswer_metadata=False
        )
    else:
        file_content_raw = extract_file_text(
            file=file_content,
            file_name=file_name,
            break_on_unprocessable=True,
        )

    # Build the web URL for the file
    web_url = f"{base_url}/navigate/file/{file_metadata['group_id']}"

    # Create document metadata
    metadata: dict[str, str | list[str]] = {
        "file_path": file_metadata["path"],
        "last_modified": file_metadata.get("last_modified", ""),
    }

    # Add lock info if present
    if lock_info := file_metadata.get("lock_info"):
        metadata[
            "lock_owner"
        ] = f"{lock_info.get('first_name', '')} {lock_info.get('last_name', '')}"

    # Create the document owners
    primary_owner = None
    if uploaded_by := file_metadata.get("uploaded_by"):
        primary_owner = BasicExpertInfo(
            email=uploaded_by,  # Using username as email since that's what we have
        )

    # Create the document
    return Document(
        id=f"egnyte-{file_metadata['entry_id']}",
        sections=[Section(text=file_content_raw.strip(), link=web_url)],
        source=DocumentSource.EGNYTE,
        semantic_identifier=file_name,
        metadata=metadata,
        doc_updated_at=(
            _parse_last_modified(file_metadata["last_modified"])
            if "last_modified" in file_metadata
            else None
        ),
        primary_owners=[primary_owner] if primary_owner else None,
    )


class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
    def __init__(
        self,
        folder_path: str | None = None,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        self.domain = ""  # will always be set in `load_credentials`
        self.folder_path = folder_path or ""  # Root folder if not specified
        self.batch_size = batch_size
        self.access_token: str | None = None

    @classmethod
    def oauth_id(cls) -> DocumentSource:
        return DocumentSource.EGNYTE

    @classmethod
    def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
        if not EGNYTE_CLIENT_ID:
            raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
        if not EGNYTE_BASE_DOMAIN:
            raise ValueError("EGNYTE_DOMAIN environment variable must be set")

        if EGNYTE_LOCALHOST_OVERRIDE:
            base_domain = EGNYTE_LOCALHOST_OVERRIDE

        callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte"
        return (
            f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
            f"?client_id={EGNYTE_CLIENT_ID}"
            f"&redirect_uri={callback_uri}"
            f"&scope=Egnyte.filesystem"
            f"&state={state}"
            f"&response_type=code"
        )

    @classmethod
    def oauth_code_to_token(cls, code: str) -> dict[str, Any]:
        if not EGNYTE_CLIENT_ID:
            raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
        if not EGNYTE_CLIENT_SECRET:
            raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
        if not EGNYTE_BASE_DOMAIN:
            raise ValueError("EGNYTE_DOMAIN environment variable must be set")

        # Exchange code for token
        url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
        data = {
            "client_id": EGNYTE_CLIENT_ID,
            "client_secret": EGNYTE_CLIENT_SECRET,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": f"{EGNYTE_LOCALHOST_OVERRIDE or ''}/connector/oauth/callback/egnyte",
            "scope": "Egnyte.filesystem",
        }
        headers = {"Content-Type": "application/x-www-form-urlencoded"}

        response = _request_with_retries(
            method="POST",
            url=url,
            data=data,
            headers=headers,
            # try a lot faster since this is a realtime flow
            backoff=0,
            delay=0.1,
        )
        if not response.ok:
            raise RuntimeError(f"Failed to exchange code for token: {response.text}")

        token_data = response.json()
        return {
            "domain": EGNYTE_BASE_DOMAIN,
            "access_token": token_data["access_token"],
        }

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        self.domain = credentials["domain"]
        self.access_token = credentials["access_token"]
        return None

    def _get_files_list(
        self,
        path: str,
    ) -> list[dict[str, Any]]:
        if not self.access_token or not self.domain:
            raise ConnectorMissingCredentialError("Egnyte")

        headers = {
            "Authorization": f"Bearer {self.access_token}",
        }

        params: dict[str, Any] = {
            "list_content": True,
        }

        url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{path or ''}"
        response = _request_with_retries(
            method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT
        )
        if not response.ok:
            raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}")

        data = response.json()
        all_files: list[dict[str, Any]] = []

        # Add files from current directory
        all_files.extend(data.get("files", []))

        # Recursively traverse folders
        for item in data.get("folders", []):
            all_files.extend(self._get_files_list(item["path"]))

        return all_files

    def _filter_files(
        self,
        files: list[dict[str, Any]],
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> list[dict[str, Any]]:
        filtered_files = []
        for file in files:
            if file["is_folder"]:
                continue

            file_modified = _parse_last_modified(file["last_modified"])
            if start_time and file_modified < start_time:
                continue
            if end_time and file_modified > end_time:
                continue

            filtered_files.append(file)

        return filtered_files

    def _process_files(
        self,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> Generator[list[Document], None, None]:
        files = self._get_files_list(self.folder_path)
        files = self._filter_files(files, start_time, end_time)

        current_batch: list[Document] = []
        for file in files:
            try:
                # Set up request with streaming enabled
                headers = {
                    "Authorization": f"Bearer {self.access_token}",
                }
                url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{file['path']}"
                response = _request_with_retries(
                    method="GET",
                    url=url,
                    headers=headers,
                    timeout=_TIMEOUT,
                    stream=True,
                )

                if not response.ok:
                    logger.error(
                        f"Failed to fetch file content: {file['path']} (status code: {response.status_code})"
                    )
                    continue

                # Stream the response content into a BytesIO buffer
                buffer = io.BytesIO()
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        buffer.write(chunk)

                # Reset buffer's position to the start
                buffer.seek(0)

                # Process the streamed file content
                doc = _process_egnyte_file(
                    file_metadata=file,
                    file_content=buffer,
                    base_url=_EGNYTE_APP_BASE.format(domain=self.domain),
                    folder_path=self.folder_path,
                )

                if doc is not None:
                    current_batch.append(doc)

                if len(current_batch) >= self.batch_size:
                    yield current_batch
                    current_batch = []

            except Exception:
                logger.exception(f"Failed to process file {file['path']}")
                continue

        if current_batch:
            yield current_batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        yield from self._process_files()

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        start_time = datetime.fromtimestamp(start, tz=timezone.utc)
        end_time = datetime.fromtimestamp(end, tz=timezone.utc)

        yield from self._process_files(start_time=start_time, end_time=end_time)


if __name__ == "__main__":
    connector = EgnyteConnector()
    connector.load_credentials(
        {
            "domain": os.environ["EGNYTE_DOMAIN"],
            "access_token": os.environ["EGNYTE_ACCESS_TOKEN"],
        }
    )
    document_batches = connector.load_from_state()
    print(next(document_batches))
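For reference, a usage sketch of the connector's poll path. It mirrors the __main__ block above but adds an explicit time window; the folder path and environment variables are illustrative only:

import os
from datetime import datetime, timedelta, timezone

from danswer.connectors.egnyte.connector import EgnyteConnector

connector = EgnyteConnector(folder_path="/Shared/Documents")  # illustrative folder
connector.load_credentials(
    {
        "domain": os.environ["EGNYTE_DOMAIN"],
        "access_token": os.environ["EGNYTE_ACCESS_TOKEN"],
    }
)

# Only pick up files modified in the last 7 days; poll_source converts the
# epoch seconds to UTC datetimes and filters on Egnyte's "last_modified".
end = datetime.now(tz=timezone.utc)
start = end - timedelta(days=7)
for batch in connector.poll_source(start.timestamp(), end.timestamp()):
    print(f"received a batch of {len(batch)} documents")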
@@ -15,6 +15,7 @@ from danswer.connectors.danswer_jira.connector import JiraConnector
from danswer.connectors.discourse.connector import DiscourseConnector
from danswer.connectors.document360.connector import Document360Connector
from danswer.connectors.dropbox.connector import DropboxConnector
from danswer.connectors.egnyte.connector import EgnyteConnector
from danswer.connectors.file.connector import LocalFileConnector
from danswer.connectors.fireflies.connector import FirefliesConnector
from danswer.connectors.freshdesk.connector import FreshdeskConnector
@@ -103,6 +104,7 @@ def identify_connector_class(
        DocumentSource.XENFORO: XenforoConnector,
        DocumentSource.FRESHDESK: FreshdeskConnector,
        DocumentSource.FIREFLIES: FirefliesConnector,
        DocumentSource.EGNYTE: EgnyteConnector,
    }
    connector_by_source = connector_map.get(source, {})
@@ -17,11 +17,11 @@ from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.db.engine import get_session_with_tenant
-from danswer.file_processing.extract_file_text import check_file_ext_is_valid
from danswer.file_processing.extract_file_text import detect_encoding
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.file_processing.extract_file_text import get_file_ext
from danswer.file_processing.extract_file_text import is_text_file_extension
+from danswer.file_processing.extract_file_text import is_valid_file_ext
from danswer.file_processing.extract_file_text import load_files_from_zip
from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.extract_file_text import read_text_file
@@ -50,7 +50,7 @@ def _read_files_and_metadata(
            file_content, ignore_dirs=True
        ):
            yield os.path.join(directory_path, file_info.filename), file, metadata
-    elif check_file_ext_is_valid(extension):
+    elif is_valid_file_ext(extension):
        yield file_name, file_content, metadata
    else:
        logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
@@ -63,7 +63,7 @@ def _process_file(
    pdf_pass: str | None = None,
) -> list[Document]:
    extension = get_file_ext(file_name)
-    if not check_file_ext_is_valid(extension):
+    if not is_valid_file_ext(extension):
        logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
        return []
@@ -2,6 +2,7 @@ import abc
from collections.abc import Iterator
from typing import Any

from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.connectors.models import SlimDocument

@@ -64,6 +65,23 @@ class SlimConnector(BaseConnector):
        raise NotImplementedError


class OAuthConnector(BaseConnector):
    @classmethod
    @abc.abstractmethod
    def oauth_id(cls) -> DocumentSource:
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def oauth_code_to_token(cls, code: str) -> dict[str, Any]:
        raise NotImplementedError


# Event driven
class EventConnector(BaseConnector):
    @abc.abstractmethod
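The EgnyteConnector added by this commit is the first implementer of this interface. As a rough, hypothetical sketch (not part of the commit), any other connector opting into the standard OAuth flow would subclass OAuthConnector next to its Load/Poll mixin and fill in the three classmethods; load_credentials and load_from_state are omitted here for brevity:

from typing import Any

from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import OAuthConnector


class HypotheticalConnector(LoadConnector, OAuthConnector):
    @classmethod
    def oauth_id(cls) -> DocumentSource:
        # The DocumentSource this connector is registered under
        return DocumentSource.EGNYTE  # stand-in; a real connector returns its own source

    @classmethod
    def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
        # URL the user is sent to for approval; `state` must round-trip back
        # to the /connector/oauth/callback/{source} route added in this commit
        callback = f"{base_domain.strip('/')}/connector/oauth/callback/hypothetical"
        return (
            "https://provider.example.com/oauth/authorize"
            f"?redirect_uri={callback}&state={state}&response_type=code"
        )

    @classmethod
    def oauth_code_to_token(cls, code: str) -> dict[str, Any]:
        # Exchange the callback code for whatever fields load_credentials() expects
        return {"access_token": f"<token obtained with {code}>"}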
@@ -70,7 +70,7 @@ def get_file_ext(file_path_or_name: str | Path) -> str:
    return extension


-def check_file_ext_is_valid(ext: str) -> bool:
+def is_valid_file_ext(ext: str) -> bool:
    return ext in VALID_FILE_EXTENSIONS


@@ -364,7 +364,7 @@ def extract_file_text(
    elif file_name is not None:
        final_extension = get_file_ext(file_name)

-    if check_file_ext_is_valid(final_extension):
+    if is_valid_file_ext(final_extension):
        return extension_to_function.get(final_extension, file_io_to_text)(file)

    # Either the file somehow has no name or the extension is not one that we recognize
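The rename is mechanical and callers keep the same behavior. A small sketch of the call pattern the file and Egnyte connectors use; whether ".pdf" actually sits in VALID_FILE_EXTENSIONS is not shown in this diff, so treat that part as an assumption:

from danswer.file_processing.extract_file_text import get_file_ext
from danswer.file_processing.extract_file_text import is_valid_file_ext

ext = get_file_ext("Quarterly Report.pdf")  # expected to be ".pdf"
if not is_valid_file_ext(ext):
    print(f"Skipping file with extension '{ext}'")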
@@ -52,6 +52,7 @@ from danswer.server.documents.connector import router as connector_router
from danswer.server.documents.credential import router as credential_router
from danswer.server.documents.document import router as document_router
from danswer.server.documents.indexing import router as indexing_router
from danswer.server.documents.standard_oauth import router as oauth_router
from danswer.server.features.document_set.api import router as document_set_router
from danswer.server.features.folder.api import router as folder_router
from danswer.server.features.notifications.api import router as notification_router
@@ -276,6 +277,7 @@ def get_application() -> FastAPI:
    )
    include_router_with_global_prefix_prepended(application, long_term_logs_router)
    include_router_with_global_prefix_prepended(application, api_key_router)
    include_router_with_global_prefix_prepended(application, oauth_router)

    if AUTH_TYPE == AuthType.DISABLED:
        # Server logs this during auth setup verification step
backend/danswer/server/documents/standard_oauth.py (new file, 146 lines)
@@ -0,0 +1,146 @@
import uuid
from typing import Annotated
from typing import cast

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.auth.users import current_user
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import OAuthConnector
from danswer.db.credentials import create_credential
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import CredentialBase
from danswer.utils.logger import setup_logger
from danswer.utils.subclasses import find_all_subclasses_in_dir

logger = setup_logger()

router = APIRouter(prefix="/connector/oauth")

_OAUTH_STATE_KEY_FMT = "oauth_state:{state}"
_OAUTH_STATE_EXPIRATION_SECONDS = 10 * 60  # 10 minutes

# Cache for OAuth connectors, populated at module load time
_OAUTH_CONNECTORS: dict[DocumentSource, type[OAuthConnector]] = {}


def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]:
    """Walk through the connectors package to find all OAuthConnector implementations"""
    global _OAUTH_CONNECTORS
    if _OAUTH_CONNECTORS:  # Return cached connectors if already discovered
        return _OAUTH_CONNECTORS

    oauth_connectors = find_all_subclasses_in_dir(
        cast(type[OAuthConnector], OAuthConnector), "danswer.connectors"
    )

    _OAUTH_CONNECTORS = {cls.oauth_id(): cls for cls in oauth_connectors}
    return _OAUTH_CONNECTORS


# Discover OAuth connectors at module load time
_discover_oauth_connectors()


class AuthorizeResponse(BaseModel):
    redirect_url: str


@router.get("/authorize/{source}")
def oauth_authorize(
    request: Request,
    source: DocumentSource,
    desired_return_url: Annotated[str | None, Query()] = None,
    _: User = Depends(current_user),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> AuthorizeResponse:
    """Initiates the OAuth flow by redirecting to the provider's auth page"""
    oauth_connectors = _discover_oauth_connectors()

    if source not in oauth_connectors:
        raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}")

    connector_cls = oauth_connectors[source]
    base_url = str(request.base_url)
    if "127.0.0.1" in base_url:
        base_url = base_url.replace("127.0.0.1", "localhost")

    # store state in redis
    if not desired_return_url:
        desired_return_url = f"{WEB_DOMAIN}/admin/connectors/{source}?step=0"
    redis_client = get_redis_client(tenant_id=tenant_id)
    state = str(uuid.uuid4())
    redis_client.set(
        _OAUTH_STATE_KEY_FMT.format(state=state),
        desired_return_url,
        ex=_OAUTH_STATE_EXPIRATION_SECONDS,
    )

    return AuthorizeResponse(
        redirect_url=connector_cls.oauth_authorization_url(base_url, state)
    )


class CallbackResponse(BaseModel):
    redirect_url: str


@router.get("/callback/{source}")
def oauth_callback(
    source: DocumentSource,
    code: Annotated[str, Query()],
    state: Annotated[str, Query()],
    db_session: Session = Depends(get_session),
    user: User = Depends(current_user),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> CallbackResponse:
    """Handles the OAuth callback and exchanges the code for tokens"""
    oauth_connectors = _discover_oauth_connectors()

    if source not in oauth_connectors:
        raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}")

    connector_cls = oauth_connectors[source]

    # get state from redis
    redis_client = get_redis_client(tenant_id=tenant_id)
    original_url_bytes = cast(
        bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state))
    )
    if not original_url_bytes:
        raise HTTPException(status_code=400, detail="Invalid OAuth state")
    original_url = original_url_bytes.decode("utf-8")

    token_info = connector_cls.oauth_code_to_token(code)

    # Create a new credential with the token info
    credential_data = CredentialBase(
        credential_json=token_info,
        admin_public=True,  # Or based on some logic/parameter
        source=source,
        name=f"{source.title()} OAuth Credential",
    )

    credential = create_credential(
        credential_data=credential_data,
        user=user,
        db_session=db_session,
    )

    return CallbackResponse(
        redirect_url=(
            f"{original_url}?credentialId={credential.id}"
            if "?" not in original_url
            else f"{original_url}&credentialId={credential.id}"
        )
    )
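A rough sketch of exercising the two endpoints above by hand; the host, the /api prefix, and authentication are assumptions here (in the real flow the web UI calls /authorize and the provider's redirect hits /callback in the browser):

import requests

API = "http://localhost:8080/api"  # assumed backend host and prefix

# Step 1: ask the backend where to send the user for authorization
resp = requests.get(
    f"{API}/connector/oauth/authorize/egnyte",
    params={"desired_return_url": "http://localhost:3000/admin/connectors/egnyte?step=0"},
)
print(resp.json()["redirect_url"])  # open this URL in a browser and approve access

# Step 2: the provider redirects the browser back to
#   {API}/connector/oauth/callback/egnyte?code=...&state=...
# which exchanges the code via oauth_code_to_token, stores a credential, and
# returns a redirect_url pointing at the original page with ?credentialId=<id>.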
backend/danswer/utils/subclasses.py (new file, 77 lines)
@@ -0,0 +1,77 @@
from __future__ import annotations

import importlib
import os
import pkgutil
import sys
from types import ModuleType
from typing import List
from typing import Type
from typing import TypeVar

T = TypeVar("T")


def import_all_modules_from_dir(dir_path: str) -> List[ModuleType]:
    """
    Imports all modules found in the given directory and its subdirectories,
    returning a list of imported module objects.
    """
    dir_path = os.path.abspath(dir_path)

    if dir_path not in sys.path:
        sys.path.insert(0, dir_path)

    imported_modules: List[ModuleType] = []

    for _, package_name, _ in pkgutil.walk_packages([dir_path]):
        try:
            module = importlib.import_module(package_name)
            imported_modules.append(module)
        except Exception as e:
            # Handle or log exceptions as needed
            print(f"Could not import {package_name}: {e}")

    return imported_modules


def all_subclasses(cls: Type[T]) -> List[Type[T]]:
    """
    Recursively find all subclasses of the given class.
    """
    direct_subs = cls.__subclasses__()
    result: List[Type[T]] = []
    for subclass in direct_subs:
        result.append(subclass)
        # Extend the result by recursively calling all_subclasses
        result.extend(all_subclasses(subclass))
    return result


def find_all_subclasses_in_dir(parent_class: Type[T], directory: str) -> List[Type[T]]:
    """
    Imports all modules from the given directory (and subdirectories),
    then returns all classes that are subclasses of parent_class.

    :param parent_class: The class to find subclasses of.
    :param directory: The directory to search for subclasses.
    :return: A list of all subclasses of parent_class found in the directory.
    """
    # First import all modules to ensure classes are loaded into memory
    import_all_modules_from_dir(directory)

    # Gather all subclasses of the given parent class
    subclasses = all_subclasses(parent_class)
    return subclasses


# Example usage:
if __name__ == "__main__":

    class Animal:
        pass

    # Suppose "mymodules" contains files that define classes inheriting from Animal
    found_subclasses = find_all_subclasses_in_dir(Animal, "mymodules")
    for sc in found_subclasses:
        print("Found subclass:", sc.__name__)
BIN  web/public/Egnyte.png (new binary file, 12 KiB, not shown)
@@ -49,6 +49,7 @@ import { useRouter } from "next/navigation";
import CardSection from "@/components/admin/CardSection";
import { prepareOAuthAuthorizationRequest } from "@/lib/oauth_utils";
import { EE_ENABLED, NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth";
export interface AdvancedConfig {
  refreshFreq: number;
  pruneFreq: number;
@@ -442,11 +443,19 @@ export default function AddConnector({
                {/* Button to pop up a form to manually enter credentials */}
                <button
                  className="mt-6 text-sm bg-background-900 px-2 py-1.5 flex text-text-200 flex-none rounded mr-4"
-                 onClick={() =>
-                   setCreateConnectorToggle(
-                     (createConnectorToggle) => !createConnectorToggle
-                   )
-                 }
+                 onClick={async () => {
+                   const redirectUrl =
+                     await getConnectorOauthRedirectUrl(connector);
+                   // if redirect is supported, just use it
+                   if (redirectUrl) {
+                     window.location.href = redirectUrl;
+                   } else {
+                     setCreateConnectorToggle(
+                       (createConnectorToggle) =>
+                         !createConnectorToggle
+                     );
+                   }
+                 }}
                >
                  Create New
                </button>
web/src/app/connector/oauth/callback/[source]/route.tsx (new file, 50 lines)
@@ -0,0 +1,50 @@
import { INTERNAL_URL } from "@/lib/constants";
import { NextRequest, NextResponse } from "next/server";

// TODO: deprecate this and just go directly to the backend via /api/...
// For some reason Egnyte doesn't work when using /api, so leaving this as is for now
// If we do try and remove this, make sure we test the Egnyte connector oauth flow
export async function GET(request: NextRequest) {
  if (process.env.NODE_ENV !== "development") {
    return NextResponse.json(
      { message: "This API is only available in development mode." },
      { status: 404 }
    );
  }

  try {
    const backendUrl = new URL(INTERNAL_URL);
    // Copy path and query parameters from incoming request
    backendUrl.pathname = request.nextUrl.pathname;
    backendUrl.search = request.nextUrl.search;

    const response = await fetch(backendUrl, {
      method: "GET",
      headers: request.headers,
      body: request.body,
      signal: request.signal,
      // @ts-ignore
      duplex: "half",
    });

    const responseData = await response.json();
    if (responseData.redirect_url) {
      return NextResponse.redirect(responseData.redirect_url);
    }

    return new NextResponse(JSON.stringify(responseData), {
      status: response.status,
      headers: response.headers,
    });
  } catch (error: unknown) {
    console.error("Proxy error:", error);
    return NextResponse.json(
      {
        message: "Proxy error",
        error:
          error instanceof Error ? error.message : "An unknown error occurred",
      },
      { status: 500 }
    );
  }
}
@@ -28,6 +28,7 @@ import {
  ConfluenceCredentialJson,
  Credential,
} from "@/lib/connectors/credentials";
import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth";

export default function CredentialSection({
  ccPair,
@@ -38,9 +39,14 @@ export default function CredentialSection({
  sourceType: ValidSources;
  refresh: () => void;
}) {
-  const makeShowCreateCredential = () => {
-    setShowModifyCredential(false);
-    setShowCreateCredential(true);
+  const makeShowCreateCredential = async () => {
+    const redirectUrl = await getConnectorOauthRedirectUrl(sourceType);
+    if (redirectUrl) {
+      window.location.href = redirectUrl;
+    } else {
+      setShowModifyCredential(false);
+      setShowCreateCredential(true);
+    }
  };

  const { data: credentials } = useSWR<Credential<ConfluenceCredentialJson>[]>(
@@ -150,9 +156,6 @@ export default function CredentialSection({
        title="Update Credentials"
      >
        <ModifyCredential
-          showCreate={() => {
-            setShowCreateCredential(true);
-          }}
          close={closeModifyCredential}
          source={sourceType}
          attachedConnector={ccPair.connector}
@@ -144,15 +144,12 @@ export default function ModifyCredential({
  attachedConnector,
  credentials,
  editableCredentials,
  source,
  defaultedCredential,

  onSwap,
  onSwitch,
  onCreateNew = () => null,
  onEditCredential,
  onDeleteCredential,
  showCreate,
  onCreateNew,
}: {
  close?: () => void;
  showIfEmpty?: boolean;
@@ -161,13 +158,11 @@ export default function ModifyCredential({
  credentials: Credential<any>[];
  editableCredentials: Credential<any>[];
  source: ValidSources;

  onSwitch?: (newCredential: Credential<any>) => void;
  onSwap?: (newCredential: Credential<any>, connectorId: number) => void;
  onCreateNew?: () => void;
  onDeleteCredential: (credential: Credential<any | null>) => void;
  onEditCredential?: (credential: Credential<ConfluenceCredentialJson>) => void;
  showCreate?: () => void;
}) {
  const [selectedCredential, setSelectedCredential] =
    useState<Credential<any> | null>(null);
@@ -244,10 +239,10 @@

      {!showIfEmpty && (
        <div className="flex mt-8 justify-between">
-          {showCreate ? (
+          {onCreateNew ? (
            <Button
              onClick={() => {
-                showCreate();
+                onCreateNew();
              }}
              className="bg-neutral-500 disabled:border-transparent
              transition-colors duration-150 ease-in disabled:bg-neutral-300
@@ -62,6 +62,7 @@ import document360Icon from "../../../public/Document360.png";
import googleSitesIcon from "../../../public/GoogleSites.png";
import zendeskIcon from "../../../public/Zendesk.svg";
import dropboxIcon from "../../../public/Dropbox.png";
import egnyteIcon from "../../../public/Egnyte.png";
import slackIcon from "../../../public/Slack.png";

import s3Icon from "../../../public/S3.png";
@@ -2725,3 +2726,17 @@ export const UserIcon = ({
    </svg>
  );
};

export const EgnyteIcon = ({
  size = 16,
  className = defaultTailwindCSS,
}: IconProps) => {
  return (
    <div
      style={{ width: `${size}px`, height: `${size}px` }}
      className={`w-[${size}px] h-[${size}px] ` + className}
    >
      <Image src={egnyteIcon} alt="Egnyte" width="96" height="96" />
    </div>
  );
};
@@ -1050,6 +1050,21 @@ For example, specifying .*-support.* as a "channel" will cause the connector to
    values: [],
    advanced_values: [],
  },
  egnyte: {
    description: "Configure Egnyte connector",
    values: [
      {
        type: "text",
        query: "Enter folder path to index:",
        label: "Folder Path",
        name: "folder_path",
        optional: true,
        description:
          "The folder path to index (e.g., '/Shared/Documents'). Leave empty to index everything.",
      },
    ],
    advanced_values: [],
  },
};
export function createConnectorInitialValues(
  connector: ConfigurableSources
@@ -195,6 +195,11 @@ export interface FirefliesCredentialJson {
export interface MediaWikiCredentialJson {}
export interface WikipediaCredentialJson extends MediaWikiCredentialJson {}

export interface EgnyteCredentialJson {
  domain: string;
  access_token: string;
}

export const credentialTemplates: Record<ValidSources, any> = {
  github: { github_access_token: "" } as GithubCredentialJson,
  gitlab: {
@@ -298,6 +303,10 @@ export const credentialTemplates: Record<ValidSources, any> = {
  fireflies: {
    fireflies_api_key: "",
  } as FirefliesCredentialJson,
  egnyte: {
    domain: "",
    access_token: "",
  } as EgnyteCredentialJson,
  xenforo: null,
  google_sites: null,
  file: null,
web/src/lib/connectors/oauth.ts (new file, 19 lines)
@@ -0,0 +1,19 @@
import { ValidSources } from "../types";

export async function getConnectorOauthRedirectUrl(
  connector: ValidSources
): Promise<string | null> {
  const response = await fetch(
    `/api/connector/oauth/authorize/${connector}?desired_return_url=${encodeURIComponent(
      window.location.href
    )}`
  );

  if (!response.ok) {
    console.error(`Failed to fetch OAuth redirect URL for ${connector}`);
    return null;
  }

  const data = await response.json();
  return data.redirect_url as string;
}
@@ -38,6 +38,7 @@ import {
  XenforoIcon,
  FreshdeskIcon,
  FirefliesIcon,
  EgnyteIcon,
} from "@/components/icons/icons";
import { ValidSources } from "./types";
import {
@@ -304,6 +305,12 @@ export const SOURCE_METADATA_MAP: SourceMap = {
    displayName: "Not Applicable",
    category: SourceCategory.Other,
  },
  egnyte: {
    icon: EgnyteIcon,
    displayName: "Egnyte",
    category: SourceCategory.Storage,
    docs: "https://docs.danswer.dev/connectors/egnyte",
  },
} as SourceMap;

function fillSourceMetadata(
@@ -309,6 +309,7 @@ export enum ValidSources {
  IngestionApi = "ingestion_api",
  Freshdesk = "freshdesk",
  Fireflies = "fireflies",
  Egnyte = "egnyte",
}

export const validAutoSyncSources = [