diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 09d0c6162c..18a6b0b38a 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -81,12 +81,6 @@ OAUTH_CLIENT_SECRET = ( or "" ) -# for future OAuth connector support -# OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "") -# OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "") -# OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "") -# OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "") - USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "") # for basic auth @@ -544,3 +538,5 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE") DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true" + +TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true" diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index 1b03c703db..771f9239e9 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -216,8 +216,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): return self._creds def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None: - primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] - self._primary_admin_email = primary_admin_email + self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] self._creds, new_creds_dict = get_google_creds( credentials=credentials, diff --git a/backend/danswer/connectors/google_utils/google_auth.py b/backend/danswer/connectors/google_utils/google_auth.py index 8a8c59d6af..6db789a0b1 100644 --- a/backend/danswer/connectors/google_utils/google_auth.py +++ b/backend/danswer/connectors/google_utils/google_auth.py @@ -1,11 +1,14 @@ import json -from typing import cast +from typing import Any from google.auth.transport.requests import Request # type: ignore from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore from danswer.configs.constants import DocumentSource +from danswer.connectors.google_utils.shared_constants import ( + DB_CREDENTIALS_AUTHENTICATION_METHOD, +) from danswer.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, ) @@ -18,14 +21,42 @@ from danswer.connectors.google_utils.shared_constants import ( from danswer.connectors.google_utils.shared_constants import ( GOOGLE_SCOPES, ) +from danswer.connectors.google_utils.shared_constants import ( + GoogleOAuthAuthenticationMethod, +) from danswer.utils.logger import setup_logger +from ee.danswer.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID +from ee.danswer.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET logger = setup_logger() +def sanitize_oauth_credentials(oauth_creds: OAuthCredentials) -> str: + """we really don't want to be persisting the client id and secret anywhere but the + environment. + + Returns a string of serialized json. + """ + + # strip the client id and secret + oauth_creds_json_str = oauth_creds.to_json() + oauth_creds_sanitized_json: dict[str, Any] = json.loads(oauth_creds_json_str) + oauth_creds_sanitized_json.pop("client_id", None) + oauth_creds_sanitized_json.pop("client_secret", None) + oauth_creds_sanitized_json_str = json.dumps(oauth_creds_sanitized_json) + return oauth_creds_sanitized_json_str + + def get_google_oauth_creds( token_json_str: str, source: DocumentSource ) -> OAuthCredentials | None: + """creds_json only needs to contain client_id, client_secret and refresh_token to + refresh the creds. + + expiry and token are optional ... however, if passing in expiry, token + should also be passed in or else we may not return any creds. + (probably a sign we should refactor the function) + """ creds_json = json.loads(token_json_str) creds = OAuthCredentials.from_authorized_user_info( info=creds_json, @@ -41,7 +72,7 @@ def get_google_oauth_creds( logger.notice("Refreshed Google Drive tokens.") return creds except Exception: - logger.exception("Failed to refresh google drive access token due to:") + logger.exception("Failed to refresh google drive access token") return None return None @@ -52,31 +83,72 @@ def get_google_creds( source: DocumentSource, ) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]: """Checks for two different types of credentials. - (1) A credential which holds a token acquired via a user going thorough + (1) A credential which holds a token acquired via a user going through the Google OAuth flow. (2) A credential which holds a service account key JSON file, which can then be used to impersonate any user in the workspace. + + Return a tuple where: + The first element is the requested credentials + The second element is a new credentials dict that the caller should write back + to the db. This happens if token rotation occurs while loading credentials. """ oauth_creds = None service_creds = None new_creds_dict = None if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials: # OAUTH - access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]) - oauth_creds = get_google_oauth_creds( - token_json_str=access_token_json_str, source=source + authentication_method: str = credentials.get( + DB_CREDENTIALS_AUTHENTICATION_METHOD, + GoogleOAuthAuthenticationMethod.UPLOADED.value, ) - # tell caller to update token stored in DB if it has changed - # (e.g. the token has been refreshed) - new_creds_json_str = oauth_creds.to_json() if oauth_creds else "" - if new_creds_json_str != access_token_json_str: - new_creds_dict = { - DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str, - DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[ - DB_CREDENTIALS_PRIMARY_ADMIN_KEY - ], - } + credentials_dict_str = credentials[DB_CREDENTIALS_DICT_TOKEN_KEY] + credentials_dict = json.loads(credentials_dict_str) + + # only send what get_google_oauth_creds needs + authorized_user_info = {} + + # oauth_interactive is sanitized and needs credentials from the environment + if ( + authentication_method + == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value + ): + authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID + authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET + else: + authorized_user_info["client_id"] = credentials_dict["client_id"] + authorized_user_info["client_secret"] = credentials_dict["client_secret"] + + authorized_user_info["refresh_token"] = credentials_dict["refresh_token"] + + authorized_user_info["token"] = credentials_dict["token"] + authorized_user_info["expiry"] = credentials_dict["expiry"] + + token_json_str = json.dumps(authorized_user_info) + oauth_creds = get_google_oauth_creds( + token_json_str=token_json_str, source=source + ) + + # tell caller to update token stored in DB if the refresh token changed + if oauth_creds: + if oauth_creds.refresh_token != authorized_user_info["refresh_token"]: + # if oauth_interactive, sanitize the credentials so they don't get stored in the db + if ( + authentication_method + == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value + ): + oauth_creds_json_str = sanitize_oauth_credentials(oauth_creds) + else: + oauth_creds_json_str = oauth_creds.to_json() + + new_creds_dict = { + DB_CREDENTIALS_DICT_TOKEN_KEY: oauth_creds_json_str, + DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[ + DB_CREDENTIALS_PRIMARY_ADMIN_KEY + ], + DB_CREDENTIALS_AUTHENTICATION_METHOD: authentication_method, + } elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials: # SERVICE ACCOUNT service_account_key_json_str = credentials[ diff --git a/backend/danswer/connectors/google_utils/google_kv.py b/backend/danswer/connectors/google_utils/google_kv.py index 7984681ed8..8478deef4a 100644 --- a/backend/danswer/connectors/google_utils/google_kv.py +++ b/backend/danswer/connectors/google_utils/google_kv.py @@ -17,6 +17,9 @@ from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY from danswer.connectors.google_utils.resources import get_drive_service from danswer.connectors.google_utils.resources import get_gmail_service +from danswer.connectors.google_utils.shared_constants import ( + DB_CREDENTIALS_AUTHENTICATION_METHOD, +) from danswer.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, ) @@ -29,6 +32,9 @@ from danswer.connectors.google_utils.shared_constants import ( from danswer.connectors.google_utils.shared_constants import ( GOOGLE_SCOPES, ) +from danswer.connectors.google_utils.shared_constants import ( + GoogleOAuthAuthenticationMethod, +) from danswer.connectors.google_utils.shared_constants import ( MISSING_SCOPES_ERROR_STR, ) @@ -96,6 +102,7 @@ def update_credential_access_tokens( user: User, db_session: Session, source: DocumentSource, + auth_method: GoogleOAuthAuthenticationMethod, ) -> OAuthCredentials | None: app_credentials = get_google_app_cred(source) flow = InstalledAppFlow.from_client_config( @@ -119,6 +126,7 @@ def update_credential_access_tokens( new_creds_dict = { DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str, DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email, + DB_CREDENTIALS_AUTHENTICATION_METHOD: auth_method.value, } if not update_credential_json(credential_id, new_creds_dict, user, db_session): @@ -129,6 +137,7 @@ def update_credential_access_tokens( def build_service_account_creds( source: DocumentSource, primary_admin_email: str | None = None, + name: str | None = None, ) -> CredentialBase: service_account_key = get_service_account_key(source=source) @@ -138,10 +147,15 @@ def build_service_account_creds( if primary_admin_email: credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = primary_admin_email + credential_dict[ + DB_CREDENTIALS_AUTHENTICATION_METHOD + ] = GoogleOAuthAuthenticationMethod.UPLOADED.value + return CredentialBase( credential_json=credential_dict, admin_public=True, source=source, + name=name, ) diff --git a/backend/danswer/connectors/google_utils/shared_constants.py b/backend/danswer/connectors/google_utils/shared_constants.py index ef3c0bb030..aebbab4514 100644 --- a/backend/danswer/connectors/google_utils/shared_constants.py +++ b/backend/danswer/connectors/google_utils/shared_constants.py @@ -1,3 +1,5 @@ +from enum import Enum as PyEnum + from danswer.configs.constants import DocumentSource # NOTE: do not need https://www.googleapis.com/auth/documents.readonly @@ -23,6 +25,19 @@ DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key" # The email saved for both auth types DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin" +# https://developers.google.com/workspace/guides/create-credentials +# Internally defined authentication method type. +# The value must be one of "oauth_interactive" or "uploaded" +# Used to disambiguate whether credentials have already been created via +# certain methods and what actions we allow users to take +DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method" + + +class GoogleOAuthAuthenticationMethod(str, PyEnum): + OAUTH_INTERACTIVE = "oauth_interactive" + UPLOADED = "uploaded" + + USER_FIELDS = "nextPageToken, users(primaryEmail)" # Error message substrings diff --git a/backend/danswer/indexing/chunker.py b/backend/danswer/indexing/chunker.py index 287d3ba2d5..86209624d2 100644 --- a/backend/danswer/indexing/chunker.py +++ b/backend/danswer/indexing/chunker.py @@ -232,7 +232,7 @@ class Chunker: logger.warning( f"Skipping section {section.text} from document " f"{document.semantic_identifier} due to empty text after cleaning " - f" with link {section_link_text}" + f"with link {section_link_text}" ) continue diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 8f9f9ae5f6..49fdfadfbf 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -52,7 +52,7 @@ from danswer.server.documents.connector import router as connector_router from danswer.server.documents.credential import router as credential_router from danswer.server.documents.document import router as document_router from danswer.server.documents.indexing import router as indexing_router -from danswer.server.documents.standard_oauth import router as oauth_router +from danswer.server.documents.standard_oauth import router as standard_oauth_router from danswer.server.features.document_set.api import router as document_set_router from danswer.server.features.folder.api import router as folder_router from danswer.server.features.notifications.api import router as notification_router @@ -75,6 +75,7 @@ from danswer.server.manage.search_settings import router as search_settings_rout from danswer.server.manage.slack_bot import router as slack_bot_management_router from danswer.server.manage.users import router as user_router from danswer.server.middleware.latency_logging import add_latency_logging_middleware +from danswer.server.oauth import router as oauth_router from danswer.server.openai_assistants_api.full_openai_assistants_api import ( get_full_openai_assistants_api_router, ) @@ -276,6 +277,7 @@ def get_application() -> FastAPI: application, get_full_openai_assistants_api_router() ) include_router_with_global_prefix_prepended(application, long_term_logs_router) + include_router_with_global_prefix_prepended(application, standard_oauth_router) include_router_with_global_prefix_prepended(application, api_key_router) include_router_with_global_prefix_prepended(application, oauth_router) diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py index e7cf00ba6d..e6afb95e73 100644 --- a/backend/danswer/server/documents/connector.py +++ b/backend/danswer/server/documents/connector.py @@ -55,6 +55,9 @@ from danswer.connectors.google_utils.google_kv import verify_csrf from danswer.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_DICT_TOKEN_KEY, ) +from danswer.connectors.google_utils.shared_constants import ( + GoogleOAuthAuthenticationMethod, +) from danswer.db.connector import create_connector from danswer.db.connector import delete_connector from danswer.db.connector import fetch_connector_by_id @@ -311,6 +314,7 @@ def upsert_service_account_credential( credential_base = build_service_account_creds( DocumentSource.GOOGLE_DRIVE, primary_admin_email=service_account_credential_request.google_primary_admin, + name="Service Account (uploaded)", ) except KvKeyNotFoundError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -319,7 +323,9 @@ def upsert_service_account_credential( delete_service_account_credentials(user, db_session, DocumentSource.GOOGLE_DRIVE) # `user=None` since this credential is not a personal credential credential = create_credential( - credential_data=credential_base, user=user, db_session=db_session + credential_data=credential_base, + user=user, + db_session=db_session, ) return ObjectCreationIdResponse(id=credential.id) @@ -494,6 +500,38 @@ def get_currently_failed_indexing_status( return indexing_statuses +@router.get("/admin/connector") +def get_connectors_by_credential( + _: User = Depends(current_curator_or_admin_user), + db_session: Session = Depends(get_session), + credential: int | None = None, +) -> list[ConnectorSnapshot]: + """Get a list of connectors. Allow filtering by a specific credential id.""" + + connectors = fetch_connectors(db_session) + + filtered_connectors = [] + for connector in connectors: + if connector.source == DocumentSource.INGESTION_API: + # don't include INGESTION_API, as it's a system level + # connector not manageable by the user + continue + + if credential is not None: + found = False + for cc_pair in connector.credentials: + if credential == cc_pair.credential_id: + found = True + break + + if not found: + continue + + filtered_connectors.append(ConnectorSnapshot.from_connector_db_model(connector)) + + return filtered_connectors + + @router.get("/admin/connector/indexing-status") def get_connector_indexing_status( secondary_index: bool = False, @@ -936,7 +974,12 @@ def gmail_callback( credential_id = int(credential_id_cookie) verify_csrf(credential_id, callback.state) credentials: Credentials | None = update_credential_access_tokens( - callback.code, credential_id, user, db_session, DocumentSource.GMAIL + callback.code, + credential_id, + user, + db_session, + DocumentSource.GMAIL, + GoogleOAuthAuthenticationMethod.UPLOADED, ) if credentials is None: raise HTTPException( @@ -962,7 +1005,12 @@ def google_drive_callback( verify_csrf(credential_id, callback.state) credentials: Credentials | None = update_credential_access_tokens( - callback.code, credential_id, user, db_session, DocumentSource.GOOGLE_DRIVE + callback.code, + credential_id, + user, + db_session, + DocumentSource.GOOGLE_DRIVE, + GoogleOAuthAuthenticationMethod.UPLOADED, ) if credentials is None: raise HTTPException( diff --git a/backend/danswer/server/documents/credential.py b/backend/danswer/server/documents/credential.py index 160664e55c..1cd118cd93 100644 --- a/backend/danswer/server/documents/credential.py +++ b/backend/danswer/server/documents/credential.py @@ -9,7 +9,6 @@ from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user from danswer.db.credentials import alter_credential from danswer.db.credentials import cleanup_gmail_credentials -from danswer.db.credentials import cleanup_google_drive_credentials from danswer.db.credentials import create_credential from danswer.db.credentials import CREDENTIAL_PERMISSIONS_TO_IGNORE from danswer.db.credentials import delete_credential @@ -133,8 +132,6 @@ def create_credential_from_model( # Temporary fix for empty Google App credentials if credential_info.source == DocumentSource.GMAIL: cleanup_gmail_credentials(db_session=db_session) - if credential_info.source == DocumentSource.GOOGLE_DRIVE: - cleanup_google_drive_credentials(db_session=db_session) credential = create_credential(credential_info, user, db_session) return ObjectCreationIdResponse( diff --git a/backend/danswer/server/oauth.py b/backend/danswer/server/oauth.py new file mode 100644 index 0000000000..f126fb60b3 --- /dev/null +++ b/backend/danswer/server/oauth.py @@ -0,0 +1,629 @@ +import base64 +import json +import uuid +from typing import Any +from typing import cast + +import requests +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from danswer.auth.users import current_user +from danswer.configs.app_configs import WEB_DOMAIN +from danswer.configs.constants import DocumentSource +from danswer.connectors.google_utils.google_auth import get_google_oauth_creds +from danswer.connectors.google_utils.google_auth import sanitize_oauth_credentials +from danswer.connectors.google_utils.shared_constants import ( + DB_CREDENTIALS_AUTHENTICATION_METHOD, +) +from danswer.connectors.google_utils.shared_constants import ( + DB_CREDENTIALS_DICT_TOKEN_KEY, +) +from danswer.connectors.google_utils.shared_constants import ( + DB_CREDENTIALS_PRIMARY_ADMIN_KEY, +) +from danswer.connectors.google_utils.shared_constants import ( + GoogleOAuthAuthenticationMethod, +) +from danswer.db.credentials import create_credential +from danswer.db.engine import get_current_tenant_id +from danswer.db.engine import get_session +from danswer.db.models import User +from danswer.redis.redis_pool import get_redis_client +from danswer.server.documents.models import CredentialBase +from danswer.utils.logger import setup_logger +from ee.danswer.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID +from ee.danswer.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET +from ee.danswer.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID +from ee.danswer.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET +from ee.danswer.configs.app_configs import OAUTH_SLACK_CLIENT_ID +from ee.danswer.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET + + +logger = setup_logger() + +router = APIRouter(prefix="/oauth") + + +class SlackOAuth: + # https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth + # Example: https://api.slack.com/authentication/oauth-v2#exchanging + + class OAuthSession(BaseModel): + """Stored in redis to be looked up on callback""" + + email: str + redirect_on_success: str | None # Where to send the user if OAuth flow succeeds + + CLIENT_ID = OAUTH_SLACK_CLIENT_ID + CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET + + TOKEN_URL = "https://slack.com/api/oauth.v2.access" + + # SCOPE is per https://docs.danswer.dev/connectors/slack + BOT_SCOPE = ( + "channels:history," + "channels:read," + "groups:history," + "groups:read," + "channels:join," + "im:history," + "users:read," + "users:read.email," + "usergroups:read" + ) + + REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback" + DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}" + + @classmethod + def generate_oauth_url(cls, state: str) -> str: + return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state) + + @classmethod + def generate_dev_oauth_url(cls, state: str) -> str: + """dev mode workaround for localhost testing + - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https + """ + + return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state) + + @classmethod + def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str: + url = ( + f"https://slack.com/oauth/v2/authorize" + f"?client_id={cls.CLIENT_ID}" + f"&redirect_uri={redirect_uri}" + f"&scope={cls.BOT_SCOPE}" + f"&state={state}" + ) + return url + + @classmethod + def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str: + """Temporary state to store in redis. to be looked up on auth response. + Returns a json string. + """ + session = SlackOAuth.OAuthSession( + email=email, redirect_on_success=redirect_on_success + ) + return session.model_dump_json() + + @classmethod + def parse_session(cls, session_json: str) -> OAuthSession: + session = SlackOAuth.OAuthSession.model_validate_json(session_json) + return session + + +class ConfluenceCloudOAuth: + """work in progress""" + + # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/ + + class OAuthSession(BaseModel): + """Stored in redis to be looked up on callback""" + + email: str + redirect_on_success: str | None # Where to send the user if OAuth flow succeeds + + CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID + CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET + TOKEN_URL = "https://auth.atlassian.com/oauth/token" + + # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/ + CONFLUENCE_OAUTH_SCOPE = ( + "read:confluence-props%20" + "read:confluence-content.all%20" + "read:confluence-content.summary%20" + "read:confluence-content.permission%20" + "read:confluence-user%20" + "read:confluence-groups%20" + "readonly:content.attachment:confluence" + ) + + REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback" + DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}" + + # eventually for Confluence Data Center + # oauth_url = ( + # f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}" + # f"&scope={CONFLUENCE_OAUTH_SCOPE_2}" + # f"&redirect_uri={redirectme_uri}" + # ) + + @classmethod + def generate_oauth_url(cls, state: str) -> str: + return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state) + + @classmethod + def generate_dev_oauth_url(cls, state: str) -> str: + """dev mode workaround for localhost testing + - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https + """ + return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state) + + @classmethod + def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str: + url = ( + "https://auth.atlassian.com/authorize" + f"?audience=api.atlassian.com" + f"&client_id={cls.CLIENT_ID}" + f"&redirect_uri={redirect_uri}" + f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}" + f"&state={state}" + "&response_type=code" + "&prompt=consent" + ) + return url + + @classmethod + def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str: + """Temporary state to store in redis. to be looked up on auth response. + Returns a json string. + """ + session = ConfluenceCloudOAuth.OAuthSession( + email=email, redirect_on_success=redirect_on_success + ) + return session.model_dump_json() + + @classmethod + def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession: + session = SlackOAuth.OAuthSession.model_validate_json(session_json) + return session + + +class GoogleDriveOAuth: + # https://developers.google.com/identity/protocols/oauth2 + # https://developers.google.com/identity/protocols/oauth2/web-server + + class OAuthSession(BaseModel): + """Stored in redis to be looked up on callback""" + + email: str + redirect_on_success: str | None # Where to send the user if OAuth flow succeeds + + CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID + CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET + + TOKEN_URL = "https://oauth2.googleapis.com/token" + + # SCOPE is per https://docs.danswer.dev/connectors/google-drive + # TODO: Merge with or use google_utils.GOOGLE_SCOPES + SCOPE = ( + "https://www.googleapis.com/auth/drive.readonly%20" + "https://www.googleapis.com/auth/drive.metadata.readonly%20" + "https://www.googleapis.com/auth/admin.directory.user.readonly%20" + "https://www.googleapis.com/auth/admin.directory.group.readonly" + ) + + REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback" + DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}" + + @classmethod + def generate_oauth_url(cls, state: str) -> str: + return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state) + + @classmethod + def generate_dev_oauth_url(cls, state: str) -> str: + """dev mode workaround for localhost testing + - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https + """ + + return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state) + + @classmethod + def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str: + # without prompt=consent, a refresh token is only issued the first time the user approves + url = ( + f"https://accounts.google.com/o/oauth2/v2/auth" + f"?client_id={cls.CLIENT_ID}" + f"&redirect_uri={redirect_uri}" + "&response_type=code" + f"&scope={cls.SCOPE}" + "&access_type=offline" + f"&state={state}" + "&prompt=consent" + ) + return url + + @classmethod + def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str: + """Temporary state to store in redis. to be looked up on auth response. + Returns a json string. + """ + session = GoogleDriveOAuth.OAuthSession( + email=email, redirect_on_success=redirect_on_success + ) + return session.model_dump_json() + + @classmethod + def parse_session(cls, session_json: str) -> OAuthSession: + session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json) + return session + + +@router.post("/prepare-authorization-request") +def prepare_authorization_request( + connector: DocumentSource, + redirect_on_success: str | None, + user: User = Depends(current_user), + tenant_id: str | None = Depends(get_current_tenant_id), +) -> JSONResponse: + """Used by the frontend to generate the url for the user's browser during auth request. + + Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/ + """ + + # create random oauth state param for security and to retrieve user data later + oauth_uuid = uuid.uuid4() + oauth_uuid_str = str(oauth_uuid) + + # urlsafe b64 encode the uuid for the oauth url + oauth_state = ( + base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8") + ) + + if connector == DocumentSource.SLACK: + oauth_url = SlackOAuth.generate_oauth_url(oauth_state) + session = SlackOAuth.session_dump_json( + email=user.email, redirect_on_success=redirect_on_success + ) + # elif connector == DocumentSource.CONFLUENCE: + # oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state) + # session = ConfluenceCloudOAuth.session_dump_json( + # email=user.email, redirect_on_success=redirect_on_success + # ) + # elif connector == DocumentSource.JIRA: + # oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state) + elif connector == DocumentSource.GOOGLE_DRIVE: + oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state) + session = GoogleDriveOAuth.session_dump_json( + email=user.email, redirect_on_success=redirect_on_success + ) + else: + oauth_url = None + + if not oauth_url: + raise HTTPException( + status_code=404, + detail=f"The document source type {connector} does not have OAuth implemented", + ) + + r = get_redis_client(tenant_id=tenant_id) + + # store important session state to retrieve when the user is redirected back + # 10 min is the max we want an oauth flow to be valid + r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600) + + return JSONResponse(content={"url": oauth_url}) + + +@router.post("/connector/slack/callback") +def handle_slack_oauth_callback( + code: str, + state: str, + user: User = Depends(current_user), + db_session: Session = Depends(get_session), + tenant_id: str | None = Depends(get_current_tenant_id), +) -> JSONResponse: + if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET: + raise HTTPException( + status_code=500, + detail="Slack client ID or client secret is not configured.", + ) + + r = get_redis_client(tenant_id=tenant_id) + + # recover the state + padded_state = state + "=" * ( + -len(state) % 4 + ) # Add padding back (Base64 decoding requires padding) + uuid_bytes = base64.urlsafe_b64decode( + padded_state + ) # Decode the Base64 string back to bytes + + # Convert bytes back to a UUID + oauth_uuid = uuid.UUID(bytes=uuid_bytes) + oauth_uuid_str = str(oauth_uuid) + + r_key = f"da_oauth:{oauth_uuid_str}" + + session_json_bytes = cast(bytes, r.get(r_key)) + if not session_json_bytes: + raise HTTPException( + status_code=400, + detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}", + ) + + session_json = session_json_bytes.decode("utf-8") + try: + session = SlackOAuth.parse_session(session_json) + + # Exchange the authorization code for an access token + response = requests.post( + SlackOAuth.TOKEN_URL, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "client_id": SlackOAuth.CLIENT_ID, + "client_secret": SlackOAuth.CLIENT_SECRET, + "code": code, + "redirect_uri": SlackOAuth.REDIRECT_URI, + }, + ) + + response_data = response.json() + + if not response_data.get("ok"): + raise HTTPException( + status_code=400, + detail=f"Slack OAuth failed: {response_data.get('error')}", + ) + + # Extract token and team information + access_token: str = response_data.get("access_token") + team_id: str = response_data.get("team", {}).get("id") + authed_user_id: str = response_data.get("authed_user", {}).get("id") + + credential_info = CredentialBase( + credential_json={"slack_bot_token": access_token}, + admin_public=True, + source=DocumentSource.SLACK, + name="Slack OAuth", + ) + + create_credential(credential_info, user, db_session) + except Exception as e: + return JSONResponse( + status_code=500, + content={ + "success": False, + "message": f"An error occurred during Slack OAuth: {str(e)}", + }, + ) + finally: + r.delete(r_key) + + # return the result + return JSONResponse( + content={ + "success": True, + "message": "Slack OAuth completed successfully.", + "team_id": team_id, + "authed_user_id": authed_user_id, + "redirect_on_success": session.redirect_on_success, + } + ) + + +# Work in progress +# @router.post("/connector/confluence/callback") +# def handle_confluence_oauth_callback( +# code: str, +# state: str, +# user: User = Depends(current_user), +# db_session: Session = Depends(get_session), +# tenant_id: str | None = Depends(get_current_tenant_id), +# ) -> JSONResponse: +# if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET: +# raise HTTPException( +# status_code=500, +# detail="Confluence client ID or client secret is not configured." +# ) + +# r = get_redis_client(tenant_id=tenant_id) + +# # recover the state +# padded_state = state + '=' * (-len(state) % 4) # Add padding back (Base64 decoding requires padding) +# uuid_bytes = base64.urlsafe_b64decode(padded_state) # Decode the Base64 string back to bytes + +# # Convert bytes back to a UUID +# oauth_uuid = uuid.UUID(bytes=uuid_bytes) +# oauth_uuid_str = str(oauth_uuid) + +# r_key = f"da_oauth:{oauth_uuid_str}" + +# result = r.get(r_key) +# if not result: +# raise HTTPException( +# status_code=400, +# detail=f"Confluence OAuth failed - OAuth state key not found: key={r_key}" +# ) + +# try: +# session = ConfluenceCloudOAuth.parse_session(result) + +# # Exchange the authorization code for an access token +# response = requests.post( +# ConfluenceCloudOAuth.TOKEN_URL, +# headers={"Content-Type": "application/x-www-form-urlencoded"}, +# data={ +# "client_id": ConfluenceCloudOAuth.CLIENT_ID, +# "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET, +# "code": code, +# "redirect_uri": ConfluenceCloudOAuth.DEV_REDIRECT_URI, +# }, +# ) + +# response_data = response.json() + +# if not response_data.get("ok"): +# raise HTTPException( +# status_code=400, +# detail=f"ConfluenceCloudOAuth OAuth failed: {response_data.get('error')}" +# ) + +# # Extract token and team information +# access_token: str = response_data.get("access_token") +# team_id: str = response_data.get("team", {}).get("id") +# authed_user_id: str = response_data.get("authed_user", {}).get("id") + +# credential_info = CredentialBase( +# credential_json={"slack_bot_token": access_token}, +# admin_public=True, +# source=DocumentSource.CONFLUENCE, +# name="Confluence OAuth", +# ) + +# logger.info(f"Slack access token: {access_token}") + +# credential = create_credential(credential_info, user, db_session) + +# logger.info(f"new_credential_id={credential.id}") +# except Exception as e: +# return JSONResponse( +# status_code=500, +# content={ +# "success": False, +# "message": f"An error occurred during Slack OAuth: {str(e)}", +# }, +# ) +# finally: +# r.delete(r_key) + +# # return the result +# return JSONResponse( +# content={ +# "success": True, +# "message": "Slack OAuth completed successfully.", +# "team_id": team_id, +# "authed_user_id": authed_user_id, +# "redirect_on_success": session.redirect_on_success, +# } +# ) + + +@router.post("/connector/google-drive/callback") +def handle_google_drive_oauth_callback( + code: str, + state: str, + user: User = Depends(current_user), + db_session: Session = Depends(get_session), + tenant_id: str | None = Depends(get_current_tenant_id), +) -> JSONResponse: + if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET: + raise HTTPException( + status_code=500, + detail="Google Drive client ID or client secret is not configured.", + ) + + r = get_redis_client(tenant_id=tenant_id) + + # recover the state + padded_state = state + "=" * ( + -len(state) % 4 + ) # Add padding back (Base64 decoding requires padding) + uuid_bytes = base64.urlsafe_b64decode( + padded_state + ) # Decode the Base64 string back to bytes + + # Convert bytes back to a UUID + oauth_uuid = uuid.UUID(bytes=uuid_bytes) + oauth_uuid_str = str(oauth_uuid) + + r_key = f"da_oauth:{oauth_uuid_str}" + + session_json_bytes = cast(bytes, r.get(r_key)) + if not session_json_bytes: + raise HTTPException( + status_code=400, + detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}", + ) + + session_json = session_json_bytes.decode("utf-8") + try: + session = GoogleDriveOAuth.parse_session(session_json) + + # Exchange the authorization code for an access token + response = requests.post( + GoogleDriveOAuth.TOKEN_URL, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "client_id": GoogleDriveOAuth.CLIENT_ID, + "client_secret": GoogleDriveOAuth.CLIENT_SECRET, + "code": code, + "redirect_uri": GoogleDriveOAuth.REDIRECT_URI, + "grant_type": "authorization_code", + }, + ) + + response.raise_for_status() + + authorization_response: dict[str, Any] = response.json() + + # the connector wants us to store the json in its authorized_user_info format + # returned from OAuthCredentials.get_authorized_user_info(). + # So refresh immediately via get_google_oauth_creds with the params filled in + # from fields in authorization_response to get the json we need + authorized_user_info = {} + authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID + authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET + authorized_user_info["refresh_token"] = authorization_response["refresh_token"] + + token_json_str = json.dumps(authorized_user_info) + oauth_creds = get_google_oauth_creds( + token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE + ) + if not oauth_creds: + raise RuntimeError("get_google_oauth_creds returned None.") + + # save off the credentials + oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds) + + credential_dict: dict[str, str] = {} + credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str + credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email + credential_dict[ + DB_CREDENTIALS_AUTHENTICATION_METHOD + ] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value + + credential_info = CredentialBase( + credential_json=credential_dict, + admin_public=True, + source=DocumentSource.GOOGLE_DRIVE, + name="OAuth (interactive)", + ) + + create_credential(credential_info, user, db_session) + except Exception as e: + return JSONResponse( + status_code=500, + content={ + "success": False, + "message": f"An error occurred during Google Drive OAuth: {str(e)}", + }, + ) + finally: + r.delete(r_key) + + # return the result + return JSONResponse( + content={ + "success": True, + "message": "Google Drive OAuth completed successfully.", + "redirect_on_success": session.redirect_on_success, + } + ) diff --git a/backend/danswer/server/utils.py b/backend/danswer/server/utils.py index f59066f9c7..4170faf297 100644 --- a/backend/danswer/server/utils.py +++ b/backend/danswer/server/utils.py @@ -14,6 +14,9 @@ from danswer.configs.app_configs import SMTP_PORT from danswer.configs.app_configs import SMTP_SERVER from danswer.configs.app_configs import SMTP_USER from danswer.configs.app_configs import WEB_DOMAIN +from danswer.connectors.google_utils.shared_constants import ( + DB_CREDENTIALS_AUTHENTICATION_METHOD, +) from danswer.db.models import User @@ -54,13 +57,20 @@ def mask_string(sensitive_str: str) -> str: def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]: masked_creds = {} for key, val in credential_dict.items(): - if not isinstance(val, str): - raise ValueError( - f"Unable to mask credentials of type other than string, cannot process request." - f"Recieved type: {type(val)}" - ) + if isinstance(val, str): + # we want to pass the authentication_method field through so the frontend + # can disambiguate credentials created by different methods + if key == DB_CREDENTIALS_AUTHENTICATION_METHOD: + masked_creds[key] = val + else: + masked_creds[key] = mask_string(val) + continue + + raise ValueError( + f"Unable to mask credentials of type other than string, cannot process request." + f"Recieved type: {type(val)}" + ) - masked_creds[key] = mask_string(val) return masked_creds diff --git a/backend/ee/danswer/configs/app_configs.py b/backend/ee/danswer/configs/app_configs.py index 057922dc24..a4753a002d 100644 --- a/backend/ee/danswer/configs/app_configs.py +++ b/backend/ee/danswer/configs/app_configs.py @@ -39,3 +39,11 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key") OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "") OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "") +OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "") +OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "") +OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "") +OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "") +OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "") +OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get( + "OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", "" +) diff --git a/backend/tests/daily/connectors/gmail/conftest.py b/backend/tests/daily/connectors/gmail/conftest.py index 5010d0b513..8ed2f01eb9 100644 --- a/backend/tests/daily/connectors/gmail/conftest.py +++ b/backend/tests/daily/connectors/gmail/conftest.py @@ -5,6 +5,9 @@ from collections.abc import Callable import pytest from danswer.connectors.gmail.connector import GmailConnector +from danswer.connectors.google_utils.shared_constants import ( + DB_CREDENTIALS_AUTHENTICATION_METHOD, +) from danswer.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, ) @@ -14,6 +17,9 @@ from danswer.connectors.google_utils.shared_constants import ( from danswer.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_PRIMARY_ADMIN_KEY, ) +from danswer.connectors.google_utils.shared_constants import ( + GoogleOAuthAuthenticationMethod, +) from tests.load_env_vars import load_env_vars @@ -59,6 +65,7 @@ def google_gmail_oauth_connector_factory() -> Callable[..., GmailConnector]: credentials_json = { DB_CREDENTIALS_DICT_TOKEN_KEY: refried_json_string, DB_CREDENTIALS_PRIMARY_ADMIN_KEY: primary_admin_email, + DB_CREDENTIALS_AUTHENTICATION_METHOD: GoogleOAuthAuthenticationMethod.UPLOADED.value, } connector.load_credentials(credentials_json) return connector @@ -82,6 +89,7 @@ def google_gmail_service_acct_connector_factory() -> Callable[..., GmailConnecto { DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: refried_json_string, DB_CREDENTIALS_PRIMARY_ADMIN_KEY: primary_admin_email, + DB_CREDENTIALS_AUTHENTICATION_METHOD: GoogleOAuthAuthenticationMethod.UPLOADED.value, } ) return connector diff --git a/backend/tests/daily/connectors/google_drive/conftest.py b/backend/tests/daily/connectors/google_drive/conftest.py index 4f525b2459..8de8784f15 100644 --- a/backend/tests/daily/connectors/google_drive/conftest.py +++ b/backend/tests/daily/connectors/google_drive/conftest.py @@ -5,6 +5,9 @@ from collections.abc import Callable import pytest from danswer.connectors.google_drive.connector import GoogleDriveConnector +from danswer.connectors.google_utils.shared_constants import ( + DB_CREDENTIALS_AUTHENTICATION_METHOD, +) from danswer.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, ) @@ -14,6 +17,9 @@ from danswer.connectors.google_utils.shared_constants import ( from danswer.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_PRIMARY_ADMIN_KEY, ) +from danswer.connectors.google_utils.shared_constants import ( + GoogleOAuthAuthenticationMethod, +) from tests.load_env_vars import load_env_vars @@ -56,7 +62,9 @@ def parse_credentials(env_str: str) -> dict: @pytest.fixture -def google_drive_oauth_connector_factory() -> Callable[..., GoogleDriveConnector]: +def google_drive_oauth_uploaded_connector_factory() -> ( + Callable[..., GoogleDriveConnector] +): def _connector_factory( primary_admin_email: str, include_shared_drives: bool, @@ -82,6 +90,7 @@ def google_drive_oauth_connector_factory() -> Callable[..., GoogleDriveConnector credentials_json = { DB_CREDENTIALS_DICT_TOKEN_KEY: refried_json_string, DB_CREDENTIALS_PRIMARY_ADMIN_KEY: primary_admin_email, + DB_CREDENTIALS_AUTHENTICATION_METHOD: GoogleOAuthAuthenticationMethod.UPLOADED.value, } connector.load_credentials(credentials_json) return connector @@ -122,6 +131,7 @@ def google_drive_service_acct_connector_factory() -> ( { DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: refried_json_string, DB_CREDENTIALS_PRIMARY_ADMIN_KEY: primary_admin_email, + DB_CREDENTIALS_AUTHENTICATION_METHOD: GoogleOAuthAuthenticationMethod.UPLOADED.value, } ) return connector diff --git a/backend/tests/daily/connectors/google_drive/test_admin_oauth.py b/backend/tests/daily/connectors/google_drive/test_admin_oauth.py index 74625ed0b2..117d294de6 100644 --- a/backend/tests/daily/connectors/google_drive/test_admin_oauth.py +++ b/backend/tests/daily/connectors/google_drive/test_admin_oauth.py @@ -35,10 +35,10 @@ from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_2_ ) def test_include_all( mock_get_api_key: MagicMock, - google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], + google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector], ) -> None: print("\n\nRunning test_include_all") - connector = google_drive_oauth_connector_factory( + connector = google_drive_oauth_uploaded_connector_factory( primary_admin_email=ADMIN_EMAIL, include_shared_drives=True, include_my_drives=True, @@ -77,10 +77,10 @@ def test_include_all( ) def test_include_shared_drives_only( mock_get_api_key: MagicMock, - google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], + google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector], ) -> None: print("\n\nRunning test_include_shared_drives_only") - connector = google_drive_oauth_connector_factory( + connector = google_drive_oauth_uploaded_connector_factory( primary_admin_email=ADMIN_EMAIL, include_shared_drives=True, include_my_drives=False, @@ -117,10 +117,10 @@ def test_include_shared_drives_only( ) def test_include_my_drives_only( mock_get_api_key: MagicMock, - google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], + google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector], ) -> None: print("\n\nRunning test_include_my_drives_only") - connector = google_drive_oauth_connector_factory( + connector = google_drive_oauth_uploaded_connector_factory( primary_admin_email=ADMIN_EMAIL, include_shared_drives=False, include_my_drives=True, @@ -147,11 +147,11 @@ def test_include_my_drives_only( ) def test_drive_one_only( mock_get_api_key: MagicMock, - google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], + google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector], ) -> None: print("\n\nRunning test_drive_one_only") drive_urls = [SHARED_DRIVE_1_URL] - connector = google_drive_oauth_connector_factory( + connector = google_drive_oauth_uploaded_connector_factory( primary_admin_email=ADMIN_EMAIL, include_shared_drives=True, include_my_drives=False, @@ -182,12 +182,12 @@ def test_drive_one_only( ) def test_folder_and_shared_drive( mock_get_api_key: MagicMock, - google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], + google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector], ) -> None: print("\n\nRunning test_folder_and_shared_drive") drive_urls = [SHARED_DRIVE_1_URL] folder_urls = [FOLDER_2_URL] - connector = google_drive_oauth_connector_factory( + connector = google_drive_oauth_uploaded_connector_factory( primary_admin_email=ADMIN_EMAIL, include_shared_drives=True, include_my_drives=False, @@ -221,7 +221,7 @@ def test_folder_and_shared_drive( ) def test_folders_only( mock_get_api_key: MagicMock, - google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], + google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector], ) -> None: print("\n\nRunning test_folders_only") folder_urls = [ @@ -234,7 +234,7 @@ def test_folders_only( shared_drive_urls = [ FOLDER_1_1_URL, ] - connector = google_drive_oauth_connector_factory( + connector = google_drive_oauth_uploaded_connector_factory( primary_admin_email=ADMIN_EMAIL, include_shared_drives=True, include_my_drives=False, @@ -266,13 +266,13 @@ def test_folders_only( ) def test_personal_folders_only( mock_get_api_key: MagicMock, - google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], + google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector], ) -> None: print("\n\nRunning test_personal_folders_only") folder_urls = [ FOLDER_3_URL, ] - connector = google_drive_oauth_connector_factory( + connector = google_drive_oauth_uploaded_connector_factory( primary_admin_email=ADMIN_EMAIL, include_shared_drives=True, include_my_drives=False, diff --git a/backend/tests/daily/connectors/google_drive/test_sections.py b/backend/tests/daily/connectors/google_drive/test_sections.py index 989bf9e9e7..1ef05729ff 100644 --- a/backend/tests/daily/connectors/google_drive/test_sections.py +++ b/backend/tests/daily/connectors/google_drive/test_sections.py @@ -15,10 +15,10 @@ from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FOLDER ) def test_google_drive_sections( mock_get_api_key: MagicMock, - google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], + google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector], google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], ) -> None: - oauth_connector = google_drive_oauth_connector_factory( + oauth_connector = google_drive_oauth_uploaded_connector_factory( primary_admin_email=ADMIN_EMAIL, include_shared_drives=False, include_my_drives=False, diff --git a/backend/tests/daily/connectors/google_drive/test_user_1_oauth.py b/backend/tests/daily/connectors/google_drive/test_user_1_oauth.py index 9f17353626..2278259acd 100644 --- a/backend/tests/daily/connectors/google_drive/test_user_1_oauth.py +++ b/backend/tests/daily/connectors/google_drive/test_user_1_oauth.py @@ -25,10 +25,10 @@ from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_FIL ) def test_all( mock_get_api_key: MagicMock, - google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], + google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector], ) -> None: print("\n\nRunning test_all") - connector = google_drive_oauth_connector_factory( + connector = google_drive_oauth_uploaded_connector_factory( primary_admin_email=TEST_USER_1_EMAIL, include_files_shared_with_me=True, include_shared_drives=True, @@ -65,10 +65,10 @@ def test_all( ) def test_shared_drives_only( mock_get_api_key: MagicMock, - google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], + google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector], ) -> None: print("\n\nRunning test_shared_drives_only") - connector = google_drive_oauth_connector_factory( + connector = google_drive_oauth_uploaded_connector_factory( primary_admin_email=TEST_USER_1_EMAIL, include_files_shared_with_me=False, include_shared_drives=True, @@ -100,10 +100,10 @@ def test_shared_drives_only( ) def test_shared_with_me_only( mock_get_api_key: MagicMock, - google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], + google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector], ) -> None: print("\n\nRunning test_shared_with_me_only") - connector = google_drive_oauth_connector_factory( + connector = google_drive_oauth_uploaded_connector_factory( primary_admin_email=TEST_USER_1_EMAIL, include_files_shared_with_me=True, include_shared_drives=False, @@ -133,10 +133,10 @@ def test_shared_with_me_only( ) def test_my_drive_only( mock_get_api_key: MagicMock, - google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], + google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector], ) -> None: print("\n\nRunning test_my_drive_only") - connector = google_drive_oauth_connector_factory( + connector = google_drive_oauth_uploaded_connector_factory( primary_admin_email=TEST_USER_1_EMAIL, include_files_shared_with_me=False, include_shared_drives=False, @@ -163,10 +163,10 @@ def test_my_drive_only( ) def test_shared_my_drive_folder( mock_get_api_key: MagicMock, - google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], + google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector], ) -> None: print("\n\nRunning test_shared_my_drive_folder") - connector = google_drive_oauth_connector_factory( + connector = google_drive_oauth_uploaded_connector_factory( primary_admin_email=TEST_USER_1_EMAIL, include_files_shared_with_me=False, include_shared_drives=False, @@ -195,10 +195,10 @@ def test_shared_my_drive_folder( ) def test_shared_drive_folder( mock_get_api_key: MagicMock, - google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], + google_drive_oauth_uploaded_connector_factory: Callable[..., GoogleDriveConnector], ) -> None: print("\n\nRunning test_shared_drive_folder") - connector = google_drive_oauth_connector_factory( + connector = google_drive_oauth_uploaded_connector_factory( primary_admin_email=TEST_USER_1_EMAIL, include_files_shared_with_me=False, include_shared_drives=False, diff --git a/web/src/app/admin/connector/[ccPairId]/IndexingAttemptsTable.tsx b/web/src/app/admin/connector/[ccPairId]/IndexingAttemptsTable.tsx index 4b56d277bb..7bc88240cb 100644 --- a/web/src/app/admin/connector/[ccPairId]/IndexingAttemptsTable.tsx +++ b/web/src/app/admin/connector/[ccPairId]/IndexingAttemptsTable.tsx @@ -239,20 +239,20 @@ export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) { Time Started Status - New Doc Cnt + New Documents
- Total Doc Cnt + New + Modified Documents - Total number of documents replaced in the index during - this indexing attempt + Total number of documents inserted or updated in the index + during this indexing attempt diff --git a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx index 31566f0343..6beb23f2ea 100644 --- a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx +++ b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx @@ -48,7 +48,11 @@ import NavigationRow from "./NavigationRow"; import { useRouter } from "next/navigation"; import CardSection from "@/components/admin/CardSection"; import { prepareOAuthAuthorizationRequest } from "@/lib/oauth_utils"; -import { EE_ENABLED, NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants"; +import { + EE_ENABLED, + NEXT_PUBLIC_CLOUD_ENABLED, + TEST_ENV, +} from "@/lib/constants"; import TemporaryLoadingModal from "@/components/TemporaryLoadingModal"; import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth"; export interface AdvancedConfig { @@ -127,7 +131,7 @@ export default function AddConnector({ setCurrentPageUrl(window.location.href); } - if (EE_ENABLED && NEXT_PUBLIC_CLOUD_ENABLED) { + if (EE_ENABLED && (NEXT_PUBLIC_CLOUD_ENABLED || TEST_ENV)) { const sourceMetadata = getSourceMetadata(connector); if (sourceMetadata?.oauthSupported == true) { setIsAuthorizeVisible(true); @@ -433,9 +437,7 @@ export default function AddConnector({ Select a credential - {connector == "google_drive" ? ( - - ) : connector == "gmail" ? ( + {connector == "gmail" ? ( ) : ( <> @@ -488,30 +490,27 @@ export default function AddConnector({
)} - {/* NOTE: connector will never be google_drive, since the ternary above will - prevent that, but still keeping this here for safety in case the above changes. */} - {(connector as ValidSources) !== "google_drive" && - createConnectorToggle && ( - setCreateConnectorToggle(false)} - > - <> - - Create a {getSourceDisplayName(connector)}{" "} - credential - - setCreateConnectorToggle(false)} - /> - - - )} + {createConnectorToggle && ( + setCreateConnectorToggle(false)} + > + <> + + Create a {getSourceDisplayName(connector)}{" "} + credential + + setCreateConnectorToggle(false)} + /> + + + )} )} diff --git a/web/src/app/admin/connectors/[connector]/oauth/callback/page.tsx b/web/src/app/admin/connectors/[connector]/oauth/callback/page.tsx index 30fd9e0747..8032c0d7b4 100644 --- a/web/src/app/admin/connectors/[connector]/oauth/callback/page.tsx +++ b/web/src/app/admin/connectors/[connector]/oauth/callback/page.tsx @@ -34,6 +34,10 @@ export default function OAuthCallbackPage() { useEffect(() => { const handleOAuthCallback = async () => { + // Examples + // connector (url segment)= "google-drive" + // sourceType (for looking up metadata) = "google_drive" + if (!code || !state) { setStatusMessage("Improperly formed OAuth authorization request."); setStatusDetails( @@ -43,7 +47,7 @@ export default function OAuthCallbackPage() { return; } - if (!connector || !isValidSource(connector)) { + if (!connector) { setStatusMessage( `The specified connector source type ${connector} does not exist.` ); @@ -52,7 +56,17 @@ export default function OAuthCallbackPage() { return; } - const sourceMetadata = getSourceMetadata(connector as ValidSources); + const sourceType = connector.replaceAll("-", "_"); + if (!isValidSource(sourceType)) { + setStatusMessage( + `The specified connector source type ${sourceType} does not exist.` + ); + setStatusDetails(`${sourceType} is not a valid source type.`); + setIsError(true); + return; + } + + const sourceMetadata = getSourceMetadata(sourceType as ValidSources); setPageTitle(`Authorize with ${sourceMetadata.displayName}`); setStatusMessage("Processing..."); @@ -60,7 +74,11 @@ export default function OAuthCallbackPage() { setIsError(false); // Ensure no error state during loading try { - const response = await handleOAuthAuthorizationResponse(code, state); + const response = await handleOAuthAuthorizationResponse( + connector, + code, + state + ); if (!response) { throw new Error("Empty response from OAuth server."); diff --git a/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx b/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx index a68296f6f6..f378121964 100644 --- a/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx @@ -303,52 +303,72 @@ export const DriveJsonUploadSection = ({ }; interface DriveCredentialSectionProps { - googleDrivePublicCredential?: Credential; + googleDrivePublicUploadedCredential?: Credential; googleDriveServiceAccountCredential?: Credential; serviceAccountKeyData?: { service_account_email: string }; appCredentialData?: { client_id: string }; setPopup: (popupSpec: PopupSpec | null) => void; refreshCredentials: () => void; - connectorExists: boolean; + connectorAssociated: boolean; user: User | null; } +async function handleRevokeAccess( + connectorAssociated: boolean, + setPopup: (popupSpec: PopupSpec | null) => void, + existingCredential: + | Credential + | Credential, + refreshCredentials: () => void +) { + if (connectorAssociated) { + const message = + "Cannot revoke the Google Drive credential while any connector is still associated with the credential. " + + "Please delete all associated connectors, then try again."; + setPopup({ + message: message, + type: "error", + }); + return; + } + + await adminDeleteCredential(existingCredential.id); + setPopup({ + message: "Successfully revoked the Google Drive credential!", + type: "success", + }); + + refreshCredentials(); +} + export const DriveAuthSection = ({ - googleDrivePublicCredential, + googleDrivePublicUploadedCredential, googleDriveServiceAccountCredential, serviceAccountKeyData, appCredentialData, setPopup, refreshCredentials, - connectorExists, + connectorAssociated, // don't allow revoke if a connector / credential pair is active with the uploaded credential user, }: DriveCredentialSectionProps) => { const router = useRouter(); const existingCredential = - googleDrivePublicCredential || googleDriveServiceAccountCredential; + googleDrivePublicUploadedCredential || googleDriveServiceAccountCredential; if (existingCredential) { return ( <>

- Existing credential already setup! + Uploaded and authenticated credential already exists!