v1 refresh drive creds during perm sync (#4768)

This commit is contained in:
Evan Lohn
2025-05-23 16:01:26 -07:00
committed by GitHub
parent 3e78c2f087
commit dad99cbec7
4 changed files with 94 additions and 6 deletions

View File

@@ -1,8 +1,12 @@
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
from ee.onyx.external_permissions.google_drive.models import PermissionType
from ee.onyx.external_permissions.google_drive.permission_retrieval import (
@@ -13,6 +17,7 @@ from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_utils.resources import get_drive_service
from onyx.connectors.google_utils.resources import RefreshableDriveObject
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
@@ -41,6 +46,20 @@ def _get_slim_doc_generator(
)
def _drive_connector_creds_getter(
google_drive_connector: GoogleDriveConnector,
) -> Callable[[], ServiceAccountCredentials | OAuthCredentials]:
def inner() -> ServiceAccountCredentials | OAuthCredentials:
if not google_drive_connector._creds_dict:
raise ValueError(
"Creds dict not found, load_credentials must be called first"
)
google_drive_connector.load_credentials(google_drive_connector._creds_dict)
return google_drive_connector.creds
return inner
def _fetch_permissions_for_permission_ids(
google_drive_connector: GoogleDriveConnector,
permission_info: dict[str, Any],
@@ -54,13 +73,22 @@ def _fetch_permissions_for_permission_ids(
if not permission_ids:
return []
drive_service = get_drive_service(
if not owner_email:
logger.warning(
f"No owner email found for document {doc_id}. Permission info: {permission_info}"
)
refreshable_drive_service = RefreshableDriveObject(
call_stack=lambda creds: get_drive_service(
creds=creds,
user_email=(owner_email or google_drive_connector.primary_admin_email),
),
creds=google_drive_connector.creds,
user_email=(owner_email or google_drive_connector.primary_admin_email),
creds_getter=_drive_connector_creds_getter(google_drive_connector),
)
return get_permissions_by_ids(
drive_service=drive_service,
drive_service=refreshable_drive_service,
doc_id=doc_id,
permission_ids=permission_ids,
)

View File

@@ -1,14 +1,13 @@
from googleapiclient.discovery import Resource # type: ignore
from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.resources import RefreshableDriveObject
from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_permissions_by_ids(
drive_service: Resource,
drive_service: RefreshableDriveObject,
doc_id: str,
permission_ids: list[str],
) -> list[GoogleDrivePermission]:

View File

@@ -220,6 +220,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
self._primary_admin_email: str | None = None
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
self._creds_dict: dict[str, Any] | None = None
# ids of folders and shared drives that have been traversed
self._retrieved_folder_and_drive_ids: set[str] = set()
@@ -273,6 +274,8 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
source=DocumentSource.GOOGLE_DRIVE,
)
self._creds_dict = new_creds_dict
return new_creds_dict
def _update_traversed_parent_ids(self, folder_id: str) -> None:

View File

@@ -1,8 +1,16 @@
from collections.abc import Callable
from typing import Any
from google.auth.exceptions import RefreshError # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.discovery import Resource # type: ignore
from onyx.utils.logger import setup_logger
logger = setup_logger()
class GoogleDriveService(Resource):
pass
@@ -20,6 +28,56 @@ class GmailService(Resource):
pass
class RefreshableDriveObject:
"""
Running Google drive service retrieval functions
involves accessing methods of the service object (ie. files().list())
which can raise a RefreshError if the access token is expired.
This class is a wrapper that propagates the ability to refresh the access token
and retry the final retrieval function until execute() is called.
"""
def __init__(
self,
call_stack: Callable[[ServiceAccountCredentials | OAuthCredentials], Any],
creds: ServiceAccountCredentials | OAuthCredentials,
creds_getter: Callable[..., ServiceAccountCredentials | OAuthCredentials],
):
self.call_stack = call_stack
self.creds = creds
self.creds_getter = creds_getter
def __getattr__(self, name: str) -> Any:
if name == "execute":
return self.make_refreshable_execute()
return RefreshableDriveObject(
lambda creds: getattr(self.call_stack(creds), name),
self.creds,
self.creds_getter,
)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return RefreshableDriveObject(
lambda creds: self.call_stack(creds)(*args, **kwargs),
self.creds,
self.creds_getter,
)
def make_refreshable_execute(self) -> Callable:
def execute(*args: Any, **kwargs: Any) -> Any:
try:
return self.call_stack(self.creds).execute(*args, **kwargs)
except RefreshError as e:
logger.warning(
f"RefreshError, going to attempt a creds refresh and retry: {e}"
)
# Refresh the access token
self.creds = self.creds_getter()
return self.call_stack(self.creds).execute(*args, **kwargs)
return execute
def _get_google_service(
service_name: str,
service_version: str,