import json from typing import Any from google.auth.transport.requests import Request # type: ignore from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET from onyx.configs.constants import DocumentSource from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_AUTHENTICATION_METHOD, ) from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, ) from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_DICT_TOKEN_KEY, ) from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_PRIMARY_ADMIN_KEY, ) from onyx.connectors.google_utils.shared_constants import ( GOOGLE_SCOPES, ) from onyx.connectors.google_utils.shared_constants import ( GoogleOAuthAuthenticationMethod, ) from onyx.utils.logger import setup_logger logger = setup_logger() def sanitize_oauth_credentials(oauth_creds: OAuthCredentials) -> str: """we really don't want to be persisting the client id and secret anywhere but the environment. Returns a string of serialized json. """ # strip the client id and secret oauth_creds_json_str = oauth_creds.to_json() oauth_creds_sanitized_json: dict[str, Any] = json.loads(oauth_creds_json_str) oauth_creds_sanitized_json.pop("client_id", None) oauth_creds_sanitized_json.pop("client_secret", None) oauth_creds_sanitized_json_str = json.dumps(oauth_creds_sanitized_json) return oauth_creds_sanitized_json_str def get_google_oauth_creds( token_json_str: str, source: DocumentSource ) -> OAuthCredentials | None: """creds_json only needs to contain client_id, client_secret and refresh_token to refresh the creds. expiry and token are optional ... however, if passing in expiry, token should also be passed in or else we may not return any creds. (probably a sign we should refactor the function) """ creds_json = json.loads(token_json_str) creds = OAuthCredentials.from_authorized_user_info( info=creds_json, scopes=GOOGLE_SCOPES[source], ) if creds.valid: return creds if creds.expired and creds.refresh_token: try: creds.refresh(Request()) if creds.valid: logger.notice("Refreshed Google Drive tokens.") return creds except Exception: logger.exception("Failed to refresh google drive access token") return None return None def get_google_creds( credentials: dict[str, str], source: DocumentSource, ) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]: """Checks for two different types of credentials. (1) A credential which holds a token acquired via a user going through the Google OAuth flow. (2) A credential which holds a service account key JSON file, which can then be used to impersonate any user in the workspace. Return a tuple where: The first element is the requested credentials The second element is a new credentials dict that the caller should write back to the db. This happens if token rotation occurs while loading credentials. """ oauth_creds = None service_creds = None new_creds_dict = None if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials: # OAUTH authentication_method: str = credentials.get( DB_CREDENTIALS_AUTHENTICATION_METHOD, GoogleOAuthAuthenticationMethod.UPLOADED.value, ) credentials_dict_str = credentials[DB_CREDENTIALS_DICT_TOKEN_KEY] credentials_dict = json.loads(credentials_dict_str) # only send what get_google_oauth_creds needs authorized_user_info = {} # oauth_interactive is sanitized and needs credentials from the environment if ( authentication_method == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value ): authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET else: authorized_user_info["client_id"] = credentials_dict["client_id"] authorized_user_info["client_secret"] = credentials_dict["client_secret"] authorized_user_info["refresh_token"] = credentials_dict["refresh_token"] authorized_user_info["token"] = credentials_dict["token"] authorized_user_info["expiry"] = credentials_dict["expiry"] token_json_str = json.dumps(authorized_user_info) oauth_creds = get_google_oauth_creds( token_json_str=token_json_str, source=source ) # tell caller to update token stored in DB if the refresh token changed if oauth_creds: if oauth_creds.refresh_token != authorized_user_info["refresh_token"]: # if oauth_interactive, sanitize the credentials so they don't get stored in the db if ( authentication_method == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value ): oauth_creds_json_str = sanitize_oauth_credentials(oauth_creds) else: oauth_creds_json_str = oauth_creds.to_json() new_creds_dict = { DB_CREDENTIALS_DICT_TOKEN_KEY: oauth_creds_json_str, DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[ DB_CREDENTIALS_PRIMARY_ADMIN_KEY ], DB_CREDENTIALS_AUTHENTICATION_METHOD: authentication_method, } elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials: # SERVICE ACCOUNT service_account_key_json_str = credentials[ DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY ] service_account_key = json.loads(service_account_key_json_str) service_creds = ServiceAccountCredentials.from_service_account_info( service_account_key, scopes=GOOGLE_SCOPES[source] ) if not service_creds.valid or not service_creds.expired: service_creds.refresh(Request()) if not service_creds.valid: raise PermissionError( f"Unable to access {source} - service account credentials are invalid." ) creds: ServiceAccountCredentials | OAuthCredentials | None = ( oauth_creds or service_creds ) if creds is None: raise PermissionError( f"Unable to access {source} - unknown credential structure." ) return creds, new_creds_dict