Feature/google drive oauth (#3365)

* first cut at slack oauth flow

* fix usage of hooks

* fix button spacing

* add additional error logging

* no dev redirect

* early cut at google drive oauth

* second pass

* switch to production uri's

* try handling oauth_interactive differently

* pass through client id and secret if uploaded

* fix call

* fix test

* temporarily disable check for testing

* Revert "temporarily disable check for testing"

This reverts commit 4b5a022a5fe38b05355a561616068af8e969def2.

* support visibility in test

* missed file

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
This commit is contained in:
rkuo-danswer 2024-12-12 10:01:59 -08:00 committed by GitHub
parent ca172f3306
commit dee1a0ecd7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
33 changed files with 1172 additions and 154 deletions

View File

@ -81,12 +81,6 @@ OAUTH_CLIENT_SECRET = (
or ""
)
# for future OAuth connector support
# OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "")
# OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "")
# OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "")
# OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "")
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
# for basic auth
@ -544,3 +538,5 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"

View File

@ -216,8 +216,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
return self._creds
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
self._primary_admin_email = primary_admin_email
self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
self._creds, new_creds_dict = get_google_creds(
credentials=credentials,

View File

@ -1,11 +1,14 @@
import json
from typing import cast
from typing import Any
from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from danswer.configs.constants import DocumentSource
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
@ -18,14 +21,42 @@ from danswer.connectors.google_utils.shared_constants import (
from danswer.connectors.google_utils.shared_constants import (
GOOGLE_SCOPES,
)
from danswer.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from danswer.utils.logger import setup_logger
from ee.danswer.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.danswer.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
logger = setup_logger()
def sanitize_oauth_credentials(oauth_creds: OAuthCredentials) -> str:
"""we really don't want to be persisting the client id and secret anywhere but the
environment.
Returns a string of serialized json.
"""
# strip the client id and secret
oauth_creds_json_str = oauth_creds.to_json()
oauth_creds_sanitized_json: dict[str, Any] = json.loads(oauth_creds_json_str)
oauth_creds_sanitized_json.pop("client_id", None)
oauth_creds_sanitized_json.pop("client_secret", None)
oauth_creds_sanitized_json_str = json.dumps(oauth_creds_sanitized_json)
return oauth_creds_sanitized_json_str
def get_google_oauth_creds(
token_json_str: str, source: DocumentSource
) -> OAuthCredentials | None:
"""creds_json only needs to contain client_id, client_secret and refresh_token to
refresh the creds.
expiry and token are optional ... however, if passing in expiry, token
should also be passed in or else we may not return any creds.
(probably a sign we should refactor the function)
"""
creds_json = json.loads(token_json_str)
creds = OAuthCredentials.from_authorized_user_info(
info=creds_json,
@ -41,7 +72,7 @@ def get_google_oauth_creds(
logger.notice("Refreshed Google Drive tokens.")
return creds
except Exception:
logger.exception("Failed to refresh google drive access token due to:")
logger.exception("Failed to refresh google drive access token")
return None
return None
@ -52,31 +83,72 @@ def get_google_creds(
source: DocumentSource,
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
"""Checks for two different types of credentials.
(1) A credential which holds a token acquired via a user going thorough
(1) A credential which holds a token acquired via a user going through
the Google OAuth flow.
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
Return a tuple where:
The first element is the requested credentials
The second element is a new credentials dict that the caller should write back
to the db. This happens if token rotation occurs while loading credentials.
"""
oauth_creds = None
service_creds = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
# OAUTH
access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])
oauth_creds = get_google_oauth_creds(
token_json_str=access_token_json_str, source=source
authentication_method: str = credentials.get(
DB_CREDENTIALS_AUTHENTICATION_METHOD,
GoogleOAuthAuthenticationMethod.UPLOADED.value,
)
# tell caller to update token stored in DB if it has changed
# (e.g. the token has been refreshed)
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
if new_creds_json_str != access_token_json_str:
new_creds_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[
DB_CREDENTIALS_PRIMARY_ADMIN_KEY
],
}
credentials_dict_str = credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]
credentials_dict = json.loads(credentials_dict_str)
# only send what get_google_oauth_creds needs
authorized_user_info = {}
# oauth_interactive is sanitized and needs credentials from the environment
if (
authentication_method
== GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
):
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
else:
authorized_user_info["client_id"] = credentials_dict["client_id"]
authorized_user_info["client_secret"] = credentials_dict["client_secret"]
authorized_user_info["refresh_token"] = credentials_dict["refresh_token"]
authorized_user_info["token"] = credentials_dict["token"]
authorized_user_info["expiry"] = credentials_dict["expiry"]
token_json_str = json.dumps(authorized_user_info)
oauth_creds = get_google_oauth_creds(
token_json_str=token_json_str, source=source
)
# tell caller to update token stored in DB if the refresh token changed
if oauth_creds:
if oauth_creds.refresh_token != authorized_user_info["refresh_token"]:
# if oauth_interactive, sanitize the credentials so they don't get stored in the db
if (
authentication_method
== GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
):
oauth_creds_json_str = sanitize_oauth_credentials(oauth_creds)
else:
oauth_creds_json_str = oauth_creds.to_json()
new_creds_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: oauth_creds_json_str,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[
DB_CREDENTIALS_PRIMARY_ADMIN_KEY
],
DB_CREDENTIALS_AUTHENTICATION_METHOD: authentication_method,
}
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
# SERVICE ACCOUNT
service_account_key_json_str = credentials[

View File

@ -17,6 +17,9 @@ from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from danswer.connectors.google_utils.resources import get_drive_service
from danswer.connectors.google_utils.resources import get_gmail_service
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
@ -29,6 +32,9 @@ from danswer.connectors.google_utils.shared_constants import (
from danswer.connectors.google_utils.shared_constants import (
GOOGLE_SCOPES,
)
from danswer.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from danswer.connectors.google_utils.shared_constants import (
MISSING_SCOPES_ERROR_STR,
)
@ -96,6 +102,7 @@ def update_credential_access_tokens(
user: User,
db_session: Session,
source: DocumentSource,
auth_method: GoogleOAuthAuthenticationMethod,
) -> OAuthCredentials | None:
app_credentials = get_google_app_cred(source)
flow = InstalledAppFlow.from_client_config(
@ -119,6 +126,7 @@ def update_credential_access_tokens(
new_creds_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
DB_CREDENTIALS_AUTHENTICATION_METHOD: auth_method.value,
}
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
@ -129,6 +137,7 @@ def update_credential_access_tokens(
def build_service_account_creds(
source: DocumentSource,
primary_admin_email: str | None = None,
name: str | None = None,
) -> CredentialBase:
service_account_key = get_service_account_key(source=source)
@ -138,10 +147,15 @@ def build_service_account_creds(
if primary_admin_email:
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = primary_admin_email
credential_dict[
DB_CREDENTIALS_AUTHENTICATION_METHOD
] = GoogleOAuthAuthenticationMethod.UPLOADED.value
return CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=source,
name=name,
)

View File

@ -1,3 +1,5 @@
from enum import Enum as PyEnum
from danswer.configs.constants import DocumentSource
# NOTE: do not need https://www.googleapis.com/auth/documents.readonly
@ -23,6 +25,19 @@ DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
# The email saved for both auth types
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
# https://developers.google.com/workspace/guides/create-credentials
# Internally defined authentication method type.
# The value must be one of "oauth_interactive" or "uploaded"
# Used to disambiguate whether credentials have already been created via
# certain methods and what actions we allow users to take
DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method"
class GoogleOAuthAuthenticationMethod(str, PyEnum):
OAUTH_INTERACTIVE = "oauth_interactive"
UPLOADED = "uploaded"
USER_FIELDS = "nextPageToken, users(primaryEmail)"
# Error message substrings

View File

@ -232,7 +232,7 @@ class Chunker:
logger.warning(
f"Skipping section {section.text} from document "
f"{document.semantic_identifier} due to empty text after cleaning "
f" with link {section_link_text}"
f"with link {section_link_text}"
)
continue

View File

@ -52,7 +52,7 @@ from danswer.server.documents.connector import router as connector_router
from danswer.server.documents.credential import router as credential_router
from danswer.server.documents.document import router as document_router
from danswer.server.documents.indexing import router as indexing_router
from danswer.server.documents.standard_oauth import router as oauth_router
from danswer.server.documents.standard_oauth import router as standard_oauth_router
from danswer.server.features.document_set.api import router as document_set_router
from danswer.server.features.folder.api import router as folder_router
from danswer.server.features.notifications.api import router as notification_router
@ -75,6 +75,7 @@ from danswer.server.manage.search_settings import router as search_settings_rout
from danswer.server.manage.slack_bot import router as slack_bot_management_router
from danswer.server.manage.users import router as user_router
from danswer.server.middleware.latency_logging import add_latency_logging_middleware
from danswer.server.oauth import router as oauth_router
from danswer.server.openai_assistants_api.full_openai_assistants_api import (
get_full_openai_assistants_api_router,
)
@ -276,6 +277,7 @@ def get_application() -> FastAPI:
application, get_full_openai_assistants_api_router()
)
include_router_with_global_prefix_prepended(application, long_term_logs_router)
include_router_with_global_prefix_prepended(application, standard_oauth_router)
include_router_with_global_prefix_prepended(application, api_key_router)
include_router_with_global_prefix_prepended(application, oauth_router)

View File

@ -55,6 +55,9 @@ from danswer.connectors.google_utils.google_kv import verify_csrf
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from danswer.db.connector import create_connector
from danswer.db.connector import delete_connector
from danswer.db.connector import fetch_connector_by_id
@ -311,6 +314,7 @@ def upsert_service_account_credential(
credential_base = build_service_account_creds(
DocumentSource.GOOGLE_DRIVE,
primary_admin_email=service_account_credential_request.google_primary_admin,
name="Service Account (uploaded)",
)
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
@ -319,7 +323,9 @@ def upsert_service_account_credential(
delete_service_account_credentials(user, db_session, DocumentSource.GOOGLE_DRIVE)
# `user=None` since this credential is not a personal credential
credential = create_credential(
credential_data=credential_base, user=user, db_session=db_session
credential_data=credential_base,
user=user,
db_session=db_session,
)
return ObjectCreationIdResponse(id=credential.id)
@ -494,6 +500,38 @@ def get_currently_failed_indexing_status(
return indexing_statuses
@router.get("/admin/connector")
def get_connectors_by_credential(
_: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
credential: int | None = None,
) -> list[ConnectorSnapshot]:
"""Get a list of connectors. Allow filtering by a specific credential id."""
connectors = fetch_connectors(db_session)
filtered_connectors = []
for connector in connectors:
if connector.source == DocumentSource.INGESTION_API:
# don't include INGESTION_API, as it's a system level
# connector not manageable by the user
continue
if credential is not None:
found = False
for cc_pair in connector.credentials:
if credential == cc_pair.credential_id:
found = True
break
if not found:
continue
filtered_connectors.append(ConnectorSnapshot.from_connector_db_model(connector))
return filtered_connectors
@router.get("/admin/connector/indexing-status")
def get_connector_indexing_status(
secondary_index: bool = False,
@ -936,7 +974,12 @@ def gmail_callback(
credential_id = int(credential_id_cookie)
verify_csrf(credential_id, callback.state)
credentials: Credentials | None = update_credential_access_tokens(
callback.code, credential_id, user, db_session, DocumentSource.GMAIL
callback.code,
credential_id,
user,
db_session,
DocumentSource.GMAIL,
GoogleOAuthAuthenticationMethod.UPLOADED,
)
if credentials is None:
raise HTTPException(
@ -962,7 +1005,12 @@ def google_drive_callback(
verify_csrf(credential_id, callback.state)
credentials: Credentials | None = update_credential_access_tokens(
callback.code, credential_id, user, db_session, DocumentSource.GOOGLE_DRIVE
callback.code,
credential_id,
user,
db_session,
DocumentSource.GOOGLE_DRIVE,
GoogleOAuthAuthenticationMethod.UPLOADED,
)
if credentials is None:
raise HTTPException(

View File

@ -9,7 +9,6 @@ from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.db.credentials import alter_credential
from danswer.db.credentials import cleanup_gmail_credentials
from danswer.db.credentials import cleanup_google_drive_credentials
from danswer.db.credentials import create_credential
from danswer.db.credentials import CREDENTIAL_PERMISSIONS_TO_IGNORE
from danswer.db.credentials import delete_credential
@ -133,8 +132,6 @@ def create_credential_from_model(
# Temporary fix for empty Google App credentials
if credential_info.source == DocumentSource.GMAIL:
cleanup_gmail_credentials(db_session=db_session)
if credential_info.source == DocumentSource.GOOGLE_DRIVE:
cleanup_google_drive_credentials(db_session=db_session)
credential = create_credential(credential_info, user, db_session)
return ObjectCreationIdResponse(

View File

@ -0,0 +1,629 @@
import base64
import json
import uuid
from typing import Any
from typing import cast
import requests
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.connectors.google_utils.google_auth import get_google_oauth_creds
from danswer.connectors.google_utils.google_auth import sanitize_oauth_credentials
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from danswer.db.credentials import create_credential
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import CredentialBase
from danswer.utils.logger import setup_logger
from ee.danswer.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
from ee.danswer.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
from ee.danswer.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.danswer.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from ee.danswer.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.danswer.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
logger = setup_logger()
router = APIRouter(prefix="/oauth")
class SlackOAuth:
# https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
# Example: https://api.slack.com/authentication/oauth-v2#exchanging
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_SLACK_CLIENT_ID
CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET
TOKEN_URL = "https://slack.com/api/oauth.v2.access"
# SCOPE is per https://docs.danswer.dev/connectors/slack
BOT_SCOPE = (
"channels:history,"
"channels:read,"
"groups:history,"
"groups:read,"
"channels:join,"
"im:history,"
"users:read,"
"users:read.email,"
"usergroups:read"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = SlackOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session
class ConfluenceCloudOAuth:
"""work in progress"""
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
CONFLUENCE_OAUTH_SCOPE = (
"read:confluence-props%20"
"read:confluence-content.all%20"
"read:confluence-content.summary%20"
"read:confluence-content.permission%20"
"read:confluence-user%20"
"read:confluence-groups%20"
"readonly:content.attachment:confluence"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
# eventually for Confluence Data Center
# oauth_url = (
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# f"&redirect_uri={redirectme_uri}"
# )
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
"https://auth.atlassian.com/authorize"
f"?audience=api.atlassian.com"
f"&client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
f"&state={state}"
"&response_type=code"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = ConfluenceCloudOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session
class GoogleDriveOAuth:
# https://developers.google.com/identity/protocols/oauth2
# https://developers.google.com/identity/protocols/oauth2/web-server
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
TOKEN_URL = "https://oauth2.googleapis.com/token"
# SCOPE is per https://docs.danswer.dev/connectors/google-drive
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
SCOPE = (
"https://www.googleapis.com/auth/drive.readonly%20"
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
"https://www.googleapis.com/auth/admin.directory.group.readonly"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# without prompt=consent, a refresh token is only issued the first time the user approves
url = (
f"https://accounts.google.com/o/oauth2/v2/auth"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
"&response_type=code"
f"&scope={cls.SCOPE}"
"&access_type=offline"
f"&state={state}"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = GoogleDriveOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
return session
@router.post("/prepare-authorization-request")
def prepare_authorization_request(
connector: DocumentSource,
redirect_on_success: str | None,
user: User = Depends(current_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Used by the frontend to generate the url for the user's browser during auth request.
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""
# create random oauth state param for security and to retrieve user data later
oauth_uuid = uuid.uuid4()
oauth_uuid_str = str(oauth_uuid)
# urlsafe b64 encode the uuid for the oauth url
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
if connector == DocumentSource.SLACK:
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
session = SlackOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
# elif connector == DocumentSource.CONFLUENCE:
# oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
# session = ConfluenceCloudOAuth.session_dump_json(
# email=user.email, redirect_on_success=redirect_on_success
# )
# elif connector == DocumentSource.JIRA:
# oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
elif connector == DocumentSource.GOOGLE_DRIVE:
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
session = GoogleDriveOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
else:
oauth_url = None
if not oauth_url:
raise HTTPException(
status_code=404,
detail=f"The document source type {connector} does not have OAuth implemented",
)
r = get_redis_client(tenant_id=tenant_id)
# store important session state to retrieve when the user is redirected back
# 10 min is the max we want an oauth flow to be valid
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
return JSONResponse(content={"url": oauth_url})
@router.post("/connector/slack/callback")
def handle_slack_oauth_callback(
code: str,
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Slack client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = SlackOAuth.parse_session(session_json)
# Exchange the authorization code for an access token
response = requests.post(
SlackOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": SlackOAuth.CLIENT_ID,
"client_secret": SlackOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": SlackOAuth.REDIRECT_URI,
},
)
response_data = response.json()
if not response_data.get("ok"):
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed: {response_data.get('error')}",
)
# Extract token and team information
access_token: str = response_data.get("access_token")
team_id: str = response_data.get("team", {}).get("id")
authed_user_id: str = response_data.get("authed_user", {}).get("id")
credential_info = CredentialBase(
credential_json={"slack_bot_token": access_token},
admin_public=True,
source=DocumentSource.SLACK,
name="Slack OAuth",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Slack OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Slack OAuth completed successfully.",
"team_id": team_id,
"authed_user_id": authed_user_id,
"redirect_on_success": session.redirect_on_success,
}
)
# Work in progress
# @router.post("/connector/confluence/callback")
# def handle_confluence_oauth_callback(
# code: str,
# state: str,
# user: User = Depends(current_user),
# db_session: Session = Depends(get_session),
# tenant_id: str | None = Depends(get_current_tenant_id),
# ) -> JSONResponse:
# if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
# raise HTTPException(
# status_code=500,
# detail="Confluence client ID or client secret is not configured."
# )
# r = get_redis_client(tenant_id=tenant_id)
# # recover the state
# padded_state = state + '=' * (-len(state) % 4) # Add padding back (Base64 decoding requires padding)
# uuid_bytes = base64.urlsafe_b64decode(padded_state) # Decode the Base64 string back to bytes
# # Convert bytes back to a UUID
# oauth_uuid = uuid.UUID(bytes=uuid_bytes)
# oauth_uuid_str = str(oauth_uuid)
# r_key = f"da_oauth:{oauth_uuid_str}"
# result = r.get(r_key)
# if not result:
# raise HTTPException(
# status_code=400,
# detail=f"Confluence OAuth failed - OAuth state key not found: key={r_key}"
# )
# try:
# session = ConfluenceCloudOAuth.parse_session(result)
# # Exchange the authorization code for an access token
# response = requests.post(
# ConfluenceCloudOAuth.TOKEN_URL,
# headers={"Content-Type": "application/x-www-form-urlencoded"},
# data={
# "client_id": ConfluenceCloudOAuth.CLIENT_ID,
# "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
# "code": code,
# "redirect_uri": ConfluenceCloudOAuth.DEV_REDIRECT_URI,
# },
# )
# response_data = response.json()
# if not response_data.get("ok"):
# raise HTTPException(
# status_code=400,
# detail=f"ConfluenceCloudOAuth OAuth failed: {response_data.get('error')}"
# )
# # Extract token and team information
# access_token: str = response_data.get("access_token")
# team_id: str = response_data.get("team", {}).get("id")
# authed_user_id: str = response_data.get("authed_user", {}).get("id")
# credential_info = CredentialBase(
# credential_json={"slack_bot_token": access_token},
# admin_public=True,
# source=DocumentSource.CONFLUENCE,
# name="Confluence OAuth",
# )
# logger.info(f"Slack access token: {access_token}")
# credential = create_credential(credential_info, user, db_session)
# logger.info(f"new_credential_id={credential.id}")
# except Exception as e:
# return JSONResponse(
# status_code=500,
# content={
# "success": False,
# "message": f"An error occurred during Slack OAuth: {str(e)}",
# },
# )
# finally:
# r.delete(r_key)
# # return the result
# return JSONResponse(
# content={
# "success": True,
# "message": "Slack OAuth completed successfully.",
# "team_id": team_id,
# "authed_user_id": authed_user_id,
# "redirect_on_success": session.redirect_on_success,
# }
# )
@router.post("/connector/google-drive/callback")
def handle_google_drive_oauth_callback(
code: str,
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Google Drive client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = GoogleDriveOAuth.parse_session(session_json)
# Exchange the authorization code for an access token
response = requests.post(
GoogleDriveOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": GoogleDriveOAuth.CLIENT_ID,
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
"grant_type": "authorization_code",
},
)
response.raise_for_status()
authorization_response: dict[str, Any] = response.json()
# the connector wants us to store the json in its authorized_user_info format
# returned from OAuthCredentials.get_authorized_user_info().
# So refresh immediately via get_google_oauth_creds with the params filled in
# from fields in authorization_response to get the json we need
authorized_user_info = {}
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
token_json_str = json.dumps(authorized_user_info)
oauth_creds = get_google_oauth_creds(
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
)
if not oauth_creds:
raise RuntimeError("get_google_oauth_creds returned None.")
# save off the credentials
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
credential_dict: dict[str, str] = {}
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
credential_dict[
DB_CREDENTIALS_AUTHENTICATION_METHOD
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
credential_info = CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
name="OAuth (interactive)",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Google Drive OAuth completed successfully.",
"redirect_on_success": session.redirect_on_success,
}
)

View File

@ -14,6 +14,9 @@ from danswer.configs.app_configs import SMTP_PORT
from danswer.configs.app_configs import SMTP_SERVER
from danswer.configs.app_configs import SMTP_USER
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from danswer.db.models import User
@ -54,13 +57,20 @@ def mask_string(sensitive_str: str) -> str:
def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
masked_creds = {}
for key, val in credential_dict.items():
if not isinstance(val, str):
raise ValueError(
f"Unable to mask credentials of type other than string, cannot process request."
f"Recieved type: {type(val)}"
)
if isinstance(val, str):
# we want to pass the authentication_method field through so the frontend
# can disambiguate credentials created by different methods
if key == DB_CREDENTIALS_AUTHENTICATION_METHOD:
masked_creds[key] = val
else:
masked_creds[key] = mask_string(val)
continue
raise ValueError(
f"Unable to mask credentials of type other than string, cannot process request."
f"Recieved type: {type(val)}"
)
masked_creds[key] = mask_string(val)
return masked_creds

View File

@ -39,3 +39,11 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "")
OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "")
OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "")
OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "")
OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "")
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
)

View File

@ -5,6 +5,9 @@ from collections.abc import Callable
import pytest
from danswer.connectors.gmail.connector import GmailConnector
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
@ -14,6 +17,9 @@ from danswer.connectors.google_utils.shared_constants import (
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from tests.load_env_vars import load_env_vars
@ -59,6 +65,7 @@ def google_gmail_oauth_connector_factory() -> Callable[..., GmailConnector]:
credentials_json = {
DB_CREDENTIALS_DICT_TOKEN_KEY: refried_json_string,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: primary_admin_email,
DB_CREDENTIALS_AUTHENTICATION_METHOD: GoogleOAuthAuthenticationMethod.UPLOADED.value,
}
connector.load_credentials(credentials_json)
return connector
@ -82,6 +89,7 @@ def google_gmail_service_acct_connector_factory() -> Callable[..., GmailConnecto
{
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: refried_json_string,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: primary_admin_email,
DB_CREDENTIALS_AUTHENTICATION_METHOD: GoogleOAuthAuthenticationMethod.UPLOADED.value,
}
)
return connector

View File

@ -5,6 +5,9 @@ from collections.abc import Callable
import pytest
from danswer.connectors.google_drive.connector import GoogleDriveConnector
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
@ -14,6 +17,9 @@ from danswer.connectors.google_utils.shared_constants import (
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from tests.load_env_vars import load_env_vars
@ -56,7 +62,9 @@ def parse_credentials(env_str: str) -> dict:
@pytest.fixture
def google_drive_oauth_connector_factory() -> Callable[..., GoogleDriveConnector]:
def google_drive_oauth_uploaded_connector_factory() -> (
Callable[..., GoogleDriveConnector]
):
def _connector_factory(
primary_admin_email: str,
include_shared_drives: bool,
@ -82,6 +90,7 @@ def google_drive_oauth_connector_factory() -> Callable[..., GoogleDriveConnector
credentials_json = {
DB_CREDENTIALS_DICT_TOKEN_KEY: refried_json_string,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: primary_admin_email,
DB_CREDENTIALS_AUTHENTICATION_METHOD: GoogleOAuthAuthenticationMethod.UPLOADED.value,
}
connector.load_credentials(credentials_json)
return connector
@ -122,6 +131,7 @@ def google_drive_service_acct_connector_factory() -> (
{
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: refried_json_string,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: primary_admin_email,
DB_CREDENTIALS_AUTHENTICATION_METHOD: GoogleOAuthAuthenticationMethod.UPLOADED.value,
}
)
return connector

View File

@ -35,10 +35,10 @@ from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_2_
)
def test_include_all(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_include_all")
connector = google_drive_oauth_connector_factory(
connector = google_drive_oauth_uploaded_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=True,
include_my_drives=True,
@ -77,10 +77,10 @@ def test_include_all(
)
def test_include_shared_drives_only(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_include_shared_drives_only")
connector = google_drive_oauth_connector_factory(
connector = google_drive_oauth_uploaded_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=True,
include_my_drives=False,
@ -117,10 +117,10 @@ def test_include_shared_drives_only(
)
def test_include_my_drives_only(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_include_my_drives_only")
connector = google_drive_oauth_connector_factory(
connector = google_drive_oauth_uploaded_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=False,
include_my_drives=True,
@ -147,11 +147,11 @@ def test_include_my_drives_only(
)
def test_drive_one_only(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_drive_one_only")
drive_urls = [SHARED_DRIVE_1_URL]
connector = google_drive_oauth_connector_factory(
connector = google_drive_oauth_uploaded_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=True,
include_my_drives=False,
@ -182,12 +182,12 @@ def test_drive_one_only(
)
def test_folder_and_shared_drive(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_folder_and_shared_drive")
drive_urls = [SHARED_DRIVE_1_URL]
folder_urls = [FOLDER_2_URL]
connector = google_drive_oauth_connector_factory(
connector = google_drive_oauth_uploaded_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=True,
include_my_drives=False,
@ -221,7 +221,7 @@ def test_folder_and_shared_drive(
)
def test_folders_only(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_folders_only")
folder_urls = [
@ -234,7 +234,7 @@ def test_folders_only(
shared_drive_urls = [
FOLDER_1_1_URL,
]
connector = google_drive_oauth_connector_factory(
connector = google_drive_oauth_uploaded_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=True,
include_my_drives=False,
@ -266,13 +266,13 @@ def test_folders_only(
)
def test_personal_folders_only(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_personal_folders_only")
folder_urls = [
FOLDER_3_URL,
]
connector = google_drive_oauth_connector_factory(
connector = google_drive_oauth_uploaded_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=True,
include_my_drives=False,

View File

@ -15,10 +15,10 @@ from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FOLDER
)
def test_google_drive_sections(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector],
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
oauth_connector = google_drive_oauth_connector_factory(
oauth_connector = google_drive_oauth_uploaded_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=False,
include_my_drives=False,

View File

@ -25,10 +25,10 @@ from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_FIL
)
def test_all(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_all")
connector = google_drive_oauth_connector_factory(
connector = google_drive_oauth_uploaded_connector_factory(
primary_admin_email=TEST_USER_1_EMAIL,
include_files_shared_with_me=True,
include_shared_drives=True,
@ -65,10 +65,10 @@ def test_all(
)
def test_shared_drives_only(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_shared_drives_only")
connector = google_drive_oauth_connector_factory(
connector = google_drive_oauth_uploaded_connector_factory(
primary_admin_email=TEST_USER_1_EMAIL,
include_files_shared_with_me=False,
include_shared_drives=True,
@ -100,10 +100,10 @@ def test_shared_drives_only(
)
def test_shared_with_me_only(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_shared_with_me_only")
connector = google_drive_oauth_connector_factory(
connector = google_drive_oauth_uploaded_connector_factory(
primary_admin_email=TEST_USER_1_EMAIL,
include_files_shared_with_me=True,
include_shared_drives=False,
@ -133,10 +133,10 @@ def test_shared_with_me_only(
)
def test_my_drive_only(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_my_drive_only")
connector = google_drive_oauth_connector_factory(
connector = google_drive_oauth_uploaded_connector_factory(
primary_admin_email=TEST_USER_1_EMAIL,
include_files_shared_with_me=False,
include_shared_drives=False,
@ -163,10 +163,10 @@ def test_my_drive_only(
)
def test_shared_my_drive_folder(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_shared_my_drive_folder")
connector = google_drive_oauth_connector_factory(
connector = google_drive_oauth_uploaded_connector_factory(
primary_admin_email=TEST_USER_1_EMAIL,
include_files_shared_with_me=False,
include_shared_drives=False,
@ -195,10 +195,10 @@ def test_shared_my_drive_folder(
)
def test_shared_drive_folder(
mock_get_api_key: MagicMock,
google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector],
google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_shared_drive_folder")
connector = google_drive_oauth_connector_factory(
connector = google_drive_oauth_uploaded_connector_factory(
primary_admin_email=TEST_USER_1_EMAIL,
include_files_shared_with_me=False,
include_shared_drives=False,

View File

@ -239,20 +239,20 @@ export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) {
<TableRow>
<TableHead>Time Started</TableHead>
<TableHead>Status</TableHead>
<TableHead>New Doc Cnt</TableHead>
<TableHead>New Documents</TableHead>
<TableHead>
<div className="w-fit">
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<span className="cursor-help flex items-center">
Total Doc Cnt
New + Modified Documents
<InfoIcon className="ml-1 w-4 h-4" />
</span>
</TooltipTrigger>
<TooltipContent>
Total number of documents replaced in the index during
this indexing attempt
Total number of documents inserted or updated in the index
during this indexing attempt
</TooltipContent>
</Tooltip>
</TooltipProvider>

View File

@ -48,7 +48,11 @@ import NavigationRow from "./NavigationRow";
import { useRouter } from "next/navigation";
import CardSection from "@/components/admin/CardSection";
import { prepareOAuthAuthorizationRequest } from "@/lib/oauth_utils";
import { EE_ENABLED, NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import {
EE_ENABLED,
NEXT_PUBLIC_CLOUD_ENABLED,
TEST_ENV,
} from "@/lib/constants";
import TemporaryLoadingModal from "@/components/TemporaryLoadingModal";
import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth";
export interface AdvancedConfig {
@ -127,7 +131,7 @@ export default function AddConnector({
setCurrentPageUrl(window.location.href);
}
if (EE_ENABLED && NEXT_PUBLIC_CLOUD_ENABLED) {
if (EE_ENABLED && (NEXT_PUBLIC_CLOUD_ENABLED || TEST_ENV)) {
const sourceMetadata = getSourceMetadata(connector);
if (sourceMetadata?.oauthSupported == true) {
setIsAuthorizeVisible(true);
@ -433,9 +437,7 @@ export default function AddConnector({
<CardSection>
<Title className="mb-2 text-lg">Select a credential</Title>
{connector == "google_drive" ? (
<GDriveMain />
) : connector == "gmail" ? (
{connector == "gmail" ? (
<GmailMain />
) : (
<>
@ -488,30 +490,27 @@ export default function AddConnector({
</div>
)}
{/* NOTE: connector will never be google_drive, since the ternary above will
prevent that, but still keeping this here for safety in case the above changes. */}
{(connector as ValidSources) !== "google_drive" &&
createConnectorToggle && (
<Modal
className="max-w-3xl rounded-lg"
onOutsideClick={() => setCreateConnectorToggle(false)}
>
<>
<Title className="mb-2 text-lg">
Create a {getSourceDisplayName(connector)}{" "}
credential
</Title>
<CreateCredential
close
refresh={refresh}
sourceType={connector}
setPopup={setPopup}
onSwitch={onSwap}
onClose={() => setCreateConnectorToggle(false)}
/>
</>
</Modal>
)}
{createConnectorToggle && (
<Modal
className="max-w-3xl rounded-lg"
onOutsideClick={() => setCreateConnectorToggle(false)}
>
<>
<Title className="mb-2 text-lg">
Create a {getSourceDisplayName(connector)}{" "}
credential
</Title>
<CreateCredential
close
refresh={refresh}
sourceType={connector}
setPopup={setPopup}
onSwitch={onSwap}
onClose={() => setCreateConnectorToggle(false)}
/>
</>
</Modal>
)}
</>
)}
</CardSection>

View File

@ -34,6 +34,10 @@ export default function OAuthCallbackPage() {
useEffect(() => {
const handleOAuthCallback = async () => {
// Examples
// connector (url segment)= "google-drive"
// sourceType (for looking up metadata) = "google_drive"
if (!code || !state) {
setStatusMessage("Improperly formed OAuth authorization request.");
setStatusDetails(
@ -43,7 +47,7 @@ export default function OAuthCallbackPage() {
return;
}
if (!connector || !isValidSource(connector)) {
if (!connector) {
setStatusMessage(
`The specified connector source type ${connector} does not exist.`
);
@ -52,7 +56,17 @@ export default function OAuthCallbackPage() {
return;
}
const sourceMetadata = getSourceMetadata(connector as ValidSources);
const sourceType = connector.replaceAll("-", "_");
if (!isValidSource(sourceType)) {
setStatusMessage(
`The specified connector source type ${sourceType} does not exist.`
);
setStatusDetails(`${sourceType} is not a valid source type.`);
setIsError(true);
return;
}
const sourceMetadata = getSourceMetadata(sourceType as ValidSources);
setPageTitle(`Authorize with ${sourceMetadata.displayName}`);
setStatusMessage("Processing...");
@ -60,7 +74,11 @@ export default function OAuthCallbackPage() {
setIsError(false); // Ensure no error state during loading
try {
const response = await handleOAuthAuthorizationResponse(code, state);
const response = await handleOAuthAuthorizationResponse(
connector,
code,
state
);
if (!response) {
throw new Error("Empty response from OAuth server.");

View File

@ -303,52 +303,72 @@ export const DriveJsonUploadSection = ({
};
interface DriveCredentialSectionProps {
googleDrivePublicCredential?: Credential<GoogleDriveCredentialJson>;
googleDrivePublicUploadedCredential?: Credential<GoogleDriveCredentialJson>;
googleDriveServiceAccountCredential?: Credential<GoogleDriveServiceAccountCredentialJson>;
serviceAccountKeyData?: { service_account_email: string };
appCredentialData?: { client_id: string };
setPopup: (popupSpec: PopupSpec | null) => void;
refreshCredentials: () => void;
connectorExists: boolean;
connectorAssociated: boolean;
user: User | null;
}
async function handleRevokeAccess(
connectorAssociated: boolean,
setPopup: (popupSpec: PopupSpec | null) => void,
existingCredential:
| Credential<GoogleDriveCredentialJson>
| Credential<GoogleDriveServiceAccountCredentialJson>,
refreshCredentials: () => void
) {
if (connectorAssociated) {
const message =
"Cannot revoke the Google Drive credential while any connector is still associated with the credential. " +
"Please delete all associated connectors, then try again.";
setPopup({
message: message,
type: "error",
});
return;
}
await adminDeleteCredential(existingCredential.id);
setPopup({
message: "Successfully revoked the Google Drive credential!",
type: "success",
});
refreshCredentials();
}
export const DriveAuthSection = ({
googleDrivePublicCredential,
googleDrivePublicUploadedCredential,
googleDriveServiceAccountCredential,
serviceAccountKeyData,
appCredentialData,
setPopup,
refreshCredentials,
connectorExists,
connectorAssociated, // don't allow revoke if a connector / credential pair is active with the uploaded credential
user,
}: DriveCredentialSectionProps) => {
const router = useRouter();
const existingCredential =
googleDrivePublicCredential || googleDriveServiceAccountCredential;
googleDrivePublicUploadedCredential || googleDriveServiceAccountCredential;
if (existingCredential) {
return (
<>
<p className="mb-2 text-sm">
<i>Existing credential already setup!</i>
<i>Uploaded and authenticated credential already exists!</i>
</p>
<Button
onClick={async () => {
if (connectorExists) {
setPopup({
message:
"Cannot revoke access to Google Drive while any connector is still setup. Please delete all connectors, then try again.",
type: "error",
});
return;
}
await adminDeleteCredential(existingCredential.id);
setPopup({
message: "Successfully revoked access to Google Drive!",
type: "success",
});
refreshCredentials();
handleRevokeAccess(
connectorAssociated,
setPopup,
existingCredential,
refreshCredentials
);
}}
>
Revoke Access
@ -429,6 +449,7 @@ export const DriveAuthSection = ({
onClick={async () => {
const [authUrl, errorMsg] = await setupGoogleDriveOAuth({
isAdmin: true,
name: "OAuth (uploaded)",
});
if (authUrl) {
// cookie used by callback to determine where to finally redirect to

View File

@ -1,12 +1,12 @@
"use client";
import React from "react";
import useSWR from "swr";
import React, { useEffect, useState } from "react";
import useSWR, { mutate } from "swr";
import { FetchError, errorHandlingFetcher } from "@/lib/fetcher";
import { ErrorCallout } from "@/components/ErrorCallout";
import { LoadingAnimation } from "@/components/Loading";
import { usePopup } from "@/components/admin/connectors/Popup";
import { ConnectorIndexingStatus } from "@/lib/types";
import { ConnectorIndexingStatus, ValidSources } from "@/lib/types";
import {
usePublicCredentials,
useConnectorCredentialIndexingStatus,
@ -18,12 +18,31 @@ import {
GoogleDriveCredentialJson,
GoogleDriveServiceAccountCredentialJson,
} from "@/lib/connectors/credentials";
import { GoogleDriveConfig } from "@/lib/connectors/connectors";
import {
ConnectorSnapshot,
GoogleDriveConfig,
} from "@/lib/connectors/connectors";
import { useUser } from "@/components/user/UserProvider";
import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib";
import { fetchConnectors } from "@/lib/connector";
const useConnectorsByCredentialId = (credential_id: number | null) => {
let url: string | null = null;
if (credential_id !== null) {
url = `/api/manage/admin/connector?credential=${credential_id}`;
}
const swrResponse = useSWR<ConnectorSnapshot[]>(url, errorHandlingFetcher);
return {
...swrResponse,
refreshConnectorsByCredentialId: () => mutate(url),
};
};
const GDriveMain = ({}: {}) => {
const { isAdmin, user } = useUser();
// tries getting the uploaded credential json
const {
data: appCredentialData,
isLoading: isAppCredentialLoading,
@ -33,6 +52,7 @@ const GDriveMain = ({}: {}) => {
errorHandlingFetcher
);
// tries getting the uploaded service account key
const {
data: serviceAccountKeyData,
isLoading: isServiceAccountKeyLoading,
@ -42,11 +62,7 @@ const GDriveMain = ({}: {}) => {
errorHandlingFetcher
);
const {
data: connectorIndexingStatuses,
isLoading: isConnectorIndexingStatusesLoading,
error: connectorIndexingStatusesError,
} = useConnectorCredentialIndexingStatus();
// gets all public credentials
const {
data: credentialsData,
isLoading: isCredentialsLoading,
@ -54,6 +70,40 @@ const GDriveMain = ({}: {}) => {
refreshCredentials,
} = usePublicCredentials();
// gets all credentials for source type google drive
const {
data: googleDriveCredentials,
isLoading: isGoogleDriveCredentialsLoading,
error: googleDriveCredentialsError,
} = useSWR<Credential<any>[]>(
buildSimilarCredentialInfoURL(ValidSources.GoogleDrive),
errorHandlingFetcher,
{ refreshInterval: 5000 }
);
// filters down to just credentials that were created via upload (there should be only one)
let credential_id = null;
if (googleDriveCredentials) {
const googleDriveUploadedCredentials: Credential<GoogleDriveCredentialJson>[] =
googleDriveCredentials.filter(
(googleDriveCredential) =>
googleDriveCredential.credential_json.authentication_method !==
"oauth_interactive"
);
if (googleDriveUploadedCredentials.length > 0) {
credential_id = googleDriveUploadedCredentials[0].id;
}
}
// retrieves all connectors for that credential id
const {
data: googleDriveConnectors,
isLoading: isGoogleDriveConnectorsLoading,
error: googleDriveConnectorsError,
refreshConnectorsByCredentialId,
} = useConnectorsByCredentialId(credential_id);
const { popup, setPopup } = usePopup();
const appCredentialSuccessfullyFetched =
@ -66,8 +116,9 @@ const GDriveMain = ({}: {}) => {
if (
(!appCredentialSuccessfullyFetched && isAppCredentialLoading) ||
(!serviceAccountKeySuccessfullyFetched && isServiceAccountKeyLoading) ||
(!connectorIndexingStatuses && isConnectorIndexingStatusesLoading) ||
(!credentialsData && isCredentialsLoading)
(!credentialsData && isCredentialsLoading) ||
(!googleDriveCredentials && isGoogleDriveCredentialsLoading) ||
(!googleDriveConnectors && isGoogleDriveConnectorsLoading)
) {
return (
<div className="mx-auto">
@ -80,8 +131,10 @@ const GDriveMain = ({}: {}) => {
return <ErrorCallout errorTitle="Failed to load credentials." />;
}
if (connectorIndexingStatusesError || !connectorIndexingStatuses) {
return <ErrorCallout errorTitle="Failed to load connectors." />;
if (googleDriveCredentialsError || !googleDriveCredentials) {
return (
<ErrorCallout errorTitle="Failed to load Google Drive credentials." />
);
}
if (
@ -93,14 +146,17 @@ const GDriveMain = ({}: {}) => {
);
}
const googleDrivePublicCredential:
// get the actual uploaded oauth or service account credentials
const googleDrivePublicUploadedCredential:
| Credential<GoogleDriveCredentialJson>
| undefined = credentialsData.find(
(credential) =>
credential.credential_json?.google_tokens &&
credential.admin_public &&
credential.source === "google_drive"
credential.source === "google_drive" &&
credential.credential_json.authentication_method !== "oauth_interactive"
);
const googleDriveServiceAccountCredential:
| Credential<GoogleDriveServiceAccountCredentialJson>
| undefined = credentialsData.find(
@ -109,13 +165,18 @@ const GDriveMain = ({}: {}) => {
credential.source === "google_drive"
);
const googleDriveConnectorIndexingStatuses: ConnectorIndexingStatus<
GoogleDriveConfig,
GoogleDriveCredentialJson
>[] = connectorIndexingStatuses.filter(
(connectorIndexingStatus) =>
connectorIndexingStatus.connector.source === "google_drive"
);
if (googleDriveConnectorsError) {
return (
<ErrorCallout errorTitle="Failed to load Google Drive associated connectors." />
);
}
let connectorAssociated = false;
if (googleDriveConnectors) {
if (googleDriveConnectors.length > 0) {
connectorAssociated = true;
}
}
return (
<>
@ -138,13 +199,15 @@ const GDriveMain = ({}: {}) => {
<DriveAuthSection
setPopup={setPopup}
refreshCredentials={refreshCredentials}
googleDrivePublicCredential={googleDrivePublicCredential}
googleDrivePublicUploadedCredential={
googleDrivePublicUploadedCredential
}
googleDriveServiceAccountCredential={
googleDriveServiceAccountCredential
}
appCredentialData={appCredentialData}
serviceAccountKeyData={serviceAccountKeyData}
connectorExists={googleDriveConnectorIndexingStatuses.length > 0}
connectorAssociated={connectorAssociated}
user={user}
/>
</>

View File

@ -19,7 +19,7 @@ interface EmbeddingFormContextType {
allowAdvanced: boolean;
setAllowAdvanced: React.Dispatch<React.SetStateAction<boolean>>;
allowCreate: boolean;
setAlowCreate: React.Dispatch<React.SetStateAction<boolean>>;
setAllowCreate: React.Dispatch<React.SetStateAction<boolean>>;
}
const EmbeddingFormContext = createContext<
@ -39,7 +39,7 @@ export const EmbeddingFormProvider: React.FC<{
const [formValues, setFormValues] = useState<Record<string, any>>({});
const [allowAdvanced, setAllowAdvanced] = useState(false);
const [allowCreate, setAlowCreate] = useState(false);
const [allowCreate, setAllowCreate] = useState(false);
const nextFormStep = (values = "") => {
setFormStep((prevStep) => prevStep + 1);
@ -88,7 +88,7 @@ export const EmbeddingFormProvider: React.FC<{
allowAdvanced,
setAllowAdvanced,
allowCreate,
setAlowCreate,
setAllowCreate: setAllowCreate,
};
return (

View File

@ -35,6 +35,9 @@ const CredentialSelectionTable = ({
number | null
>(null);
// rkuo: this appears to merge editableCredentials into credentials so we get a single list
// of credentials to display
// Pretty sure this merging should be done outside of this UI component
const allCredentials = React.useMemo(() => {
const credMap = new Map(editableCredentials.map((cred) => [cred.id, cred]));
credentials.forEach((cred) => {

View File

@ -1,6 +1,10 @@
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { ValidSources } from "./types";
import { Connector, ConnectorBase } from "./connectors/connectors";
import {
Connector,
ConnectorBase,
ConnectorSnapshot,
} from "./connectors/connectors";
async function handleResponse(
response: Response
): Promise<[string | null, any]> {
@ -11,6 +15,18 @@ async function handleResponse(
return [responseJson.detail, null];
}
export async function fetchConnectors(
credential_id: number
): Promise<ConnectorSnapshot[]> {
const url = `/api/manage/admin/connector?credential=${credential_id}`;
const response = await fetch(url);
if (!response.ok) {
throw new Error(`Failed to fetch connectors: ${await response.text()}`);
}
const connectors: ConnectorSnapshot[] = await response.json();
return connectors;
}
export async function createConnector<T>(
connector: ConnectorBase<T>
): Promise<[string | null, Connector<T> | null]> {

View File

@ -1145,6 +1145,20 @@ export interface Connector<T> extends ConnectorBase<T> {
time_updated: string;
}
export interface ConnectorSnapshot {
id: number;
name: string;
source: ValidSources;
input_type: ValidInputTypes;
// connector_specific_config
refresh_freq: number | null;
prune_freq: number | null;
credential_ids: number[];
indexing_start: number | null;
time_created: string;
time_updated: string;
}
export interface WebConfig {
base_url: string;
web_connector_type?: "recursive" | "single" | "sitemap";

View File

@ -60,6 +60,7 @@ export interface GmailCredentialJson {
export interface GoogleDriveCredentialJson {
google_tokens: string;
google_primary_admin: string;
authentication_method?: string;
}
export interface GmailServiceAccountCredentialJson {
@ -70,6 +71,7 @@ export interface GmailServiceAccountCredentialJson {
export interface GoogleDriveServiceAccountCredentialJson {
google_service_account_key: string;
google_primary_admin: string;
authentication_method?: string;
}
export interface SlabCredentialJson {

View File

@ -76,3 +76,5 @@ export const REGISTRATION_URL =
export const SERVER_SIDE_ONLY__CLOUD_ENABLED =
process.env.NEXT_PUBLIC_CLOUD_ENABLED?.toLowerCase() === "true";
export const TEST_ENV = process.env.TEST_ENV?.toLowerCase() === "true";

View File

@ -2,8 +2,10 @@ import { Credential } from "./connectors/credentials";
export const setupGoogleDriveOAuth = async ({
isAdmin,
name,
}: {
isAdmin: boolean;
name: string;
}): Promise<[string | null, string]> => {
const credentialCreationResponse = await fetch("/api/manage/credential", {
method: "POST",
@ -14,6 +16,7 @@ export const setupGoogleDriveOAuth = async ({
admin_public: isAdmin,
credential_json: {},
source: "google_drive",
name: name,
}),
});

View File

@ -1,4 +1,5 @@
import {
OAuthGoogleDriveCallbackResponse,
OAuthPrepareAuthorizationResponse,
OAuthSlackCallbackResponse,
} from "./types";
@ -39,9 +40,25 @@ export async function prepareOAuthAuthorizationRequest(
return data;
}
export async function handleOAuthAuthorizationResponse(
connector: string,
code: string,
state: string
) {
if (connector === "slack") {
return handleOAuthSlackAuthorizationResponse(code, state);
}
if (connector === "google-drive") {
return handleOAuthGoogleDriveAuthorizationResponse(code, state);
}
return;
}
// server side handler to process the oauth redirect callback
// https://api.slack.com/authentication/oauth-v2#exchanging
export async function handleOAuthAuthorizationResponse(
export async function handleOAuthSlackAuthorizationResponse(
code: string,
state: string
): Promise<OAuthSlackCallbackResponse> {
@ -78,3 +95,43 @@ export async function handleOAuthAuthorizationResponse(
const data = (await response.json()) as OAuthSlackCallbackResponse;
return data;
}
// server side handler to process the oauth redirect callback
// https://api.slack.com/authentication/oauth-v2#exchanging
export async function handleOAuthGoogleDriveAuthorizationResponse(
code: string,
state: string
): Promise<OAuthGoogleDriveCallbackResponse> {
const url = `/api/oauth/connector/google-drive/callback?code=${encodeURIComponent(
code
)}&state=${encodeURIComponent(state)}`;
const response = await fetch(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ code, state }),
});
if (!response.ok) {
let errorDetails = `Failed to handle OAuth authorization response: ${response.status}`;
try {
const responseBody = await response.text(); // Read the body as text
errorDetails += `\nResponse Body: ${responseBody}`;
} catch (err) {
if (err instanceof Error) {
errorDetails += `\nUnable to read response body: ${err.message}`;
} else {
errorDetails += `\nUnable to read response body: Unknown error type`;
}
}
throw new Error(errorDetails);
}
// Parse the JSON response
const data = (await response.json()) as OAuthGoogleDriveCallbackResponse;
return data;
}

View File

@ -90,6 +90,7 @@ export const SOURCE_METADATA_MAP: SourceMap = {
displayName: "Google Drive",
category: SourceCategory.Storage,
docs: "https://docs.danswer.dev/connectors/google_drive/overview",
oauthSupported: true,
},
github: {
icon: GithubIcon,

View File

@ -69,7 +69,11 @@ export interface MinimalUserSnapshot {
email: string;
}
export type ValidInputTypes = "load_state" | "poll" | "event";
export type ValidInputTypes =
| "load_state"
| "poll"
| "event"
| "slim_retrieval";
export type ValidStatuses =
| "success"
| "completed_with_errors"
@ -147,6 +151,12 @@ export interface OAuthSlackCallbackResponse {
redirect_on_success: string;
}
export interface OAuthGoogleDriveCallbackResponse {
success: boolean;
message: string;
redirect_on_success: string;
}
export interface CCPairBasicInfo {
has_successful_run: boolean;
source: ValidSources;
@ -329,6 +339,7 @@ export type ConfigurableSources = Exclude<
export const oauthSupportedSources: ConfigurableSources[] = [
ValidSources.Slack,
ValidSources.GoogleDrive,
];
export type OAuthSupportedSource = (typeof oauthSupportedSources)[number];

View File

@ -43,7 +43,7 @@ test(
);
await expect(page.locator("p.text-text-500")).toHaveText(
"invalid-connector is not a valid source type."
"invalid_connector is not a valid source type."
);
}
);