mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-12 21:00:07 +02:00
180 lines
6.8 KiB
Python
180 lines
6.8 KiB
Python
import json
|
|
from typing import Any
|
|
|
|
from google.auth.transport.requests import Request # type: ignore
|
|
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
|
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
|
|
|
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
|
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
|
from onyx.configs.constants import DocumentSource
|
|
from onyx.connectors.google_utils.shared_constants import (
|
|
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
|
)
|
|
from onyx.connectors.google_utils.shared_constants import (
|
|
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
|
)
|
|
from onyx.connectors.google_utils.shared_constants import (
|
|
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
|
)
|
|
from onyx.connectors.google_utils.shared_constants import (
|
|
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
|
)
|
|
from onyx.connectors.google_utils.shared_constants import (
|
|
GOOGLE_SCOPES,
|
|
)
|
|
from onyx.connectors.google_utils.shared_constants import (
|
|
GoogleOAuthAuthenticationMethod,
|
|
)
|
|
from onyx.utils.logger import setup_logger
|
|
|
|
# Module-level logger used by the credential helpers below to report token
# refresh successes and failures.
logger = setup_logger()
|
|
|
|
|
|
def sanitize_oauth_credentials(oauth_creds: OAuthCredentials) -> str:
    """Serialize OAuth credentials to JSON with the client id and secret removed.

    The client id and secret should never be persisted anywhere but the
    environment, so both fields are dropped before the result is handed back.

    Returns the remaining credential fields as a string of serialized json.
    """
    serialized: dict[str, Any] = json.loads(oauth_creds.to_json())

    for secret_field in ("client_id", "client_secret"):
        # a missing key is fine: there is simply nothing to strip
        serialized.pop(secret_field, None)

    return json.dumps(serialized)
|
|
|
|
|
|
def get_google_oauth_creds(
    token_json_str: str, source: DocumentSource
) -> OAuthCredentials | None:
    """Build OAuth credentials from serialized JSON, refreshing them if expired.

    The JSON only needs to contain client_id, client_secret and refresh_token
    for a refresh to be possible. expiry and token are optional; however, if
    expiry is supplied, token should be supplied too, otherwise the
    credentials may be neither valid nor refreshable and None is returned.

    Returns the usable credentials, or None when they are invalid and could
    not be refreshed (refresh failures are logged, not raised).
    """
    user_info = json.loads(token_json_str)
    creds = OAuthCredentials.from_authorized_user_info(
        info=user_info,
        scopes=GOOGLE_SCOPES[source],
    )

    # already usable: nothing else to do
    if creds.valid:
        return creds

    # a refresh is only attempted for expired creds that carry a refresh token
    if not (creds.expired and creds.refresh_token):
        return None

    try:
        creds.refresh(Request())
        if creds.valid:
            logger.notice("Refreshed Google Drive tokens.")
            return creds
    except Exception:
        logger.exception("Failed to refresh google drive access token")

    # either the refresh raised or it did not produce valid credentials
    return None
|
|
|
|
|
|
def get_google_creds(
    credentials: dict[str, str],
    source: DocumentSource,
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
    """Load Google credentials from a stored credentials dict.

    Checks for two different types of credentials.
    (1) A credential which holds a token acquired via a user going through
    the Google OAuth flow.
    (2) A credential which holds a service account key JSON file, which
    can then be used to impersonate any user in the workspace.

    Args:
        credentials: the stored credentials dict; it is expected to contain
            either the OAuth token key or the service account key.
        source: the document source, used to select the Google scopes and in
            error messages.

    Returns:
        A tuple where:
        The first element is the requested credentials
        The second element is a new credentials dict that the caller should write back
        to the db. This happens if token rotation occurs while loading credentials.
        It is None when nothing needs to be written back.

    Raises:
        PermissionError: if the service account credentials are invalid after
            a refresh, or if no usable credentials could be built from the
            given dict.
        KeyError: if a required field (client_id / client_secret for uploaded
            credentials, refresh_token, or the primary admin key when a new
            dict must be written back) is missing.
    """
    oauth_creds = None
    service_creds = None
    new_creds_dict = None

    if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
        # OAUTH
        authentication_method: str = credentials.get(
            DB_CREDENTIALS_AUTHENTICATION_METHOD,
            GoogleOAuthAuthenticationMethod.UPLOADED.value,
        )
        is_oauth_interactive = (
            authentication_method
            == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
        )

        credentials_dict = json.loads(credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])

        # only send what get_google_oauth_creds needs
        authorized_user_info = {}

        # oauth_interactive is sanitized and needs credentials from the environment
        if is_oauth_interactive:
            authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
            authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
        else:
            authorized_user_info["client_id"] = credentials_dict["client_id"]
            authorized_user_info["client_secret"] = credentials_dict["client_secret"]

        authorized_user_info["refresh_token"] = credentials_dict["refresh_token"]

        # token and expiry are optional for get_google_oauth_creds, and a
        # serialized credential may omit them; do not fail on their absence
        authorized_user_info["token"] = credentials_dict.get("token")
        authorized_user_info["expiry"] = credentials_dict.get("expiry")

        oauth_creds = get_google_oauth_creds(
            token_json_str=json.dumps(authorized_user_info), source=source
        )

        # tell caller to update token stored in DB if the refresh token changed
        if (
            oauth_creds
            and oauth_creds.refresh_token != authorized_user_info["refresh_token"]
        ):
            # if oauth_interactive, sanitize the credentials so they don't get stored in the db
            if is_oauth_interactive:
                oauth_creds_json_str = sanitize_oauth_credentials(oauth_creds)
            else:
                oauth_creds_json_str = oauth_creds.to_json()

            new_creds_dict = {
                DB_CREDENTIALS_DICT_TOKEN_KEY: oauth_creds_json_str,
                DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[
                    DB_CREDENTIALS_PRIMARY_ADMIN_KEY
                ],
                DB_CREDENTIALS_AUTHENTICATION_METHOD: authentication_method,
            }
    elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
        # SERVICE ACCOUNT
        service_account_key = json.loads(
            credentials[DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY]
        )

        service_creds = ServiceAccountCredentials.from_service_account_info(
            service_account_key, scopes=GOOGLE_SCOPES[source]
        )

        # The previous check (`not valid or not expired`) was always true,
        # because valid credentials are never expired. Freshly built service
        # account credentials hold no token yet, so they are still refreshed
        # here to obtain one.
        if not service_creds.valid:
            service_creds.refresh(Request())

        if not service_creds.valid:
            raise PermissionError(
                f"Unable to access {source} - service account credentials are invalid."
            )

    creds: ServiceAccountCredentials | OAuthCredentials | None = (
        oauth_creds or service_creds
    )
    if creds is None:
        raise PermissionError(
            f"Unable to access {source} - unknown credential structure."
        )

    return creds, new_creds_dict
|