mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-29 21:37:21 +02:00
post rebase fixes
This commit is contained in:
@@ -1,11 +1,16 @@
|
||||
import json
|
||||
from typing import cast
|
||||
from typing import Any
|
||||
|
||||
from google.auth.transport.requests import Request # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
@@ -18,14 +23,40 @@ from onyx.connectors.google_utils.shared_constants import (
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GOOGLE_SCOPES,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def sanitize_oauth_credentials(oauth_creds: OAuthCredentials) -> str:
|
||||
"""we really don't want to be persisting the client id and secret anywhere but the
|
||||
environment.
|
||||
|
||||
Returns a string of serialized json.
|
||||
"""
|
||||
|
||||
# strip the client id and secret
|
||||
oauth_creds_json_str = oauth_creds.to_json()
|
||||
oauth_creds_sanitized_json: dict[str, Any] = json.loads(oauth_creds_json_str)
|
||||
oauth_creds_sanitized_json.pop("client_id", None)
|
||||
oauth_creds_sanitized_json.pop("client_secret", None)
|
||||
oauth_creds_sanitized_json_str = json.dumps(oauth_creds_sanitized_json)
|
||||
return oauth_creds_sanitized_json_str
|
||||
|
||||
|
||||
def get_google_oauth_creds(
|
||||
token_json_str: str, source: DocumentSource
|
||||
) -> OAuthCredentials | None:
|
||||
"""creds_json only needs to contain client_id, client_secret and refresh_token to
|
||||
refresh the creds.
|
||||
|
||||
expiry and token are optional ... however, if passing in expiry, token
|
||||
should also be passed in or else we may not return any creds.
|
||||
(probably a sign we should refactor the function)
|
||||
"""
|
||||
creds_json = json.loads(token_json_str)
|
||||
creds = OAuthCredentials.from_authorized_user_info(
|
||||
info=creds_json,
|
||||
@@ -41,7 +72,7 @@ def get_google_oauth_creds(
|
||||
logger.notice("Refreshed Google Drive tokens.")
|
||||
return creds
|
||||
except Exception:
|
||||
logger.exception("Failed to refresh google drive access token due to:")
|
||||
logger.exception("Failed to refresh google drive access token")
|
||||
return None
|
||||
|
||||
return None
|
||||
@@ -52,31 +83,72 @@ def get_google_creds(
|
||||
source: DocumentSource,
|
||||
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
|
||||
"""Checks for two different types of credentials.
|
||||
(1) A credential which holds a token acquired via a user going thorough
|
||||
(1) A credential which holds a token acquired via a user going through
|
||||
the Google OAuth flow.
|
||||
(2) A credential which holds a service account key JSON file, which
|
||||
can then be used to impersonate any user in the workspace.
|
||||
|
||||
Return a tuple where:
|
||||
The first element is the requested credentials
|
||||
The second element is a new credentials dict that the caller should write back
|
||||
to the db. This happens if token rotation occurs while loading credentials.
|
||||
"""
|
||||
oauth_creds = None
|
||||
service_creds = None
|
||||
new_creds_dict = None
|
||||
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
|
||||
# OAUTH
|
||||
access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])
|
||||
oauth_creds = get_google_oauth_creds(
|
||||
token_json_str=access_token_json_str, source=source
|
||||
authentication_method: str = credentials.get(
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
GoogleOAuthAuthenticationMethod.UPLOADED.value,
|
||||
)
|
||||
|
||||
# tell caller to update token stored in DB if it has changed
|
||||
# (e.g. the token has been refreshed)
|
||||
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
|
||||
if new_creds_json_str != access_token_json_str:
|
||||
new_creds_dict = {
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY
|
||||
],
|
||||
}
|
||||
credentials_dict_str = credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]
|
||||
credentials_dict = json.loads(credentials_dict_str)
|
||||
|
||||
# only send what get_google_oauth_creds needs
|
||||
authorized_user_info = {}
|
||||
|
||||
# oauth_interactive is sanitized and needs credentials from the environment
|
||||
if (
|
||||
authentication_method
|
||||
== GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
|
||||
):
|
||||
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
else:
|
||||
authorized_user_info["client_id"] = credentials_dict["client_id"]
|
||||
authorized_user_info["client_secret"] = credentials_dict["client_secret"]
|
||||
|
||||
authorized_user_info["refresh_token"] = credentials_dict["refresh_token"]
|
||||
|
||||
authorized_user_info["token"] = credentials_dict["token"]
|
||||
authorized_user_info["expiry"] = credentials_dict["expiry"]
|
||||
|
||||
token_json_str = json.dumps(authorized_user_info)
|
||||
oauth_creds = get_google_oauth_creds(
|
||||
token_json_str=token_json_str, source=source
|
||||
)
|
||||
|
||||
# tell caller to update token stored in DB if the refresh token changed
|
||||
if oauth_creds:
|
||||
if oauth_creds.refresh_token != authorized_user_info["refresh_token"]:
|
||||
# if oauth_interactive, sanitize the credentials so they don't get stored in the db
|
||||
if (
|
||||
authentication_method
|
||||
== GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
|
||||
):
|
||||
oauth_creds_json_str = sanitize_oauth_credentials(oauth_creds)
|
||||
else:
|
||||
oauth_creds_json_str = oauth_creds.to_json()
|
||||
|
||||
new_creds_dict = {
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY: oauth_creds_json_str,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY
|
||||
],
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD: authentication_method,
|
||||
}
|
||||
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
|
||||
# SERVICE ACCOUNT
|
||||
service_account_key_json_str = credentials[
|
||||
|
Reference in New Issue
Block a user