mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-08-29 07:04:16 +02:00
v1 refresh drive creds during perm sync (#4768)
This commit is contained in:
@@ -1,8 +1,12 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
|
||||
from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
|
||||
from ee.onyx.external_permissions.google_drive.models import PermissionType
|
||||
from ee.onyx.external_permissions.google_drive.permission_retrieval import (
|
||||
@@ -13,6 +17,7 @@ from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_utils.resources import get_drive_service
|
||||
from onyx.connectors.google_utils.resources import RefreshableDriveObject
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
@@ -41,6 +46,20 @@ def _get_slim_doc_generator(
|
||||
)
|
||||
|
||||
|
||||
def _drive_connector_creds_getter(
|
||||
google_drive_connector: GoogleDriveConnector,
|
||||
) -> Callable[[], ServiceAccountCredentials | OAuthCredentials]:
|
||||
def inner() -> ServiceAccountCredentials | OAuthCredentials:
|
||||
if not google_drive_connector._creds_dict:
|
||||
raise ValueError(
|
||||
"Creds dict not found, load_credentials must be called first"
|
||||
)
|
||||
google_drive_connector.load_credentials(google_drive_connector._creds_dict)
|
||||
return google_drive_connector.creds
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def _fetch_permissions_for_permission_ids(
|
||||
google_drive_connector: GoogleDriveConnector,
|
||||
permission_info: dict[str, Any],
|
||||
@@ -54,13 +73,22 @@ def _fetch_permissions_for_permission_ids(
|
||||
if not permission_ids:
|
||||
return []
|
||||
|
||||
drive_service = get_drive_service(
|
||||
if not owner_email:
|
||||
logger.warning(
|
||||
f"No owner email found for document {doc_id}. Permission info: {permission_info}"
|
||||
)
|
||||
|
||||
refreshable_drive_service = RefreshableDriveObject(
|
||||
call_stack=lambda creds: get_drive_service(
|
||||
creds=creds,
|
||||
user_email=(owner_email or google_drive_connector.primary_admin_email),
|
||||
),
|
||||
creds=google_drive_connector.creds,
|
||||
user_email=(owner_email or google_drive_connector.primary_admin_email),
|
||||
creds_getter=_drive_connector_creds_getter(google_drive_connector),
|
||||
)
|
||||
|
||||
return get_permissions_by_ids(
|
||||
drive_service=drive_service,
|
||||
drive_service=refreshable_drive_service,
|
||||
doc_id=doc_id,
|
||||
permission_ids=permission_ids,
|
||||
)
|
||||
|
@@ -1,14 +1,13 @@
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
|
||||
from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
|
||||
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
from onyx.connectors.google_utils.resources import RefreshableDriveObject
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_permissions_by_ids(
|
||||
drive_service: Resource,
|
||||
drive_service: RefreshableDriveObject,
|
||||
doc_id: str,
|
||||
permission_ids: list[str],
|
||||
) -> list[GoogleDrivePermission]:
|
||||
|
@@ -220,6 +220,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
|
||||
self._primary_admin_email: str | None = None
|
||||
|
||||
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
self._creds_dict: dict[str, Any] | None = None
|
||||
|
||||
# ids of folders and shared drives that have been traversed
|
||||
self._retrieved_folder_and_drive_ids: set[str] = set()
|
||||
@@ -273,6 +274,8 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
)
|
||||
|
||||
self._creds_dict = new_creds_dict
|
||||
|
||||
return new_creds_dict
|
||||
|
||||
def _update_traversed_parent_ids(self, folder_id: str) -> None:
|
||||
|
@@ -1,8 +1,16 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from google.auth.exceptions import RefreshError # type: ignore
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient.discovery import build # type: ignore
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class GoogleDriveService(Resource):
    """
    Subclass of googleapiclient's Resource that adds no behavior of its own.

    NOTE(review): presumably exists so Drive service handles get a distinct
    name in type annotations - confirm against callers.
    """
|
||||
@@ -20,6 +28,56 @@ class GmailService(Resource):
|
||||
pass
|
||||
|
||||
|
||||
class RefreshableDriveObject:
    """
    Lazy proxy for a Google Drive service object that can refresh credentials.

    Running Google Drive service retrieval functions involves accessing
    methods of the service object (i.e. files().list()), which can raise a
    RefreshError if the access token is expired.

    Attribute accesses and calls made on this wrapper are not performed
    immediately. Each one returns a new wrapper whose `call_stack` replays the
    whole chain on top of a credentials object. Only `execute()` runs the
    chain; if that raises a RefreshError, new credentials are obtained from
    `creds_getter` and the chain is replayed once more with them.

    Wrappers derived from one another share one credentials slot, so a refresh
    done by any of them is used by every later call made through the others.
    """

    # Names that belong to the wrapper itself. If normal lookup fails for one
    # of them (e.g. on an instance created without __init__ by copy/pickle),
    # the lookup must fail instead of being forwarded to the wrapped service.
    _OWN_ATTRIBUTES = frozenset({"call_stack", "creds", "creds_getter", "_creds_box"})

    def __init__(
        self,
        call_stack: Callable[[ServiceAccountCredentials | OAuthCredentials], Any],
        creds: ServiceAccountCredentials | OAuthCredentials,
        creds_getter: Callable[..., ServiceAccountCredentials | OAuthCredentials],
        *,
        _creds_box: list[Any] | None = None,
    ):
        """
        Args:
            call_stack: builds the wrapped object from a credentials object.
            creds: credentials used for the first attempt.
            creds_getter: returns new credentials after a RefreshError.
            _creds_box: internal; one-element list shared by derived wrappers.
                When given, it takes precedence over `creds`.
        """
        self.call_stack = call_stack
        self.creds_getter = creds_getter
        self._creds_box = [creds] if _creds_box is None else _creds_box

    @property
    def creds(self) -> Any:
        """Credentials currently shared by this wrapper and its relatives."""
        return self._creds_box[0]

    @creds.setter
    def creds(self, value: Any) -> None:
        self._creds_box[0] = value

    def _derive(
        self,
        call_stack: Callable[[ServiceAccountCredentials | OAuthCredentials], Any],
    ) -> "RefreshableDriveObject":
        """Return a wrapper for the next step that shares this one's credentials."""
        return RefreshableDriveObject(
            call_stack,
            self.creds,
            self.creds_getter,
            _creds_box=self._creds_box,
        )

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal lookup fails. Protocol probes such as
        # __deepcopy__/__getstate__ and the wrapper's own attribute names must
        # raise AttributeError: forwarding them would make hasattr() always
        # true and would recurse forever on a partially built instance.
        is_dunder = name.startswith("__") and name.endswith("__")
        if is_dunder or name in self._OWN_ATTRIBUTES:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )

        if name == "execute":
            return self.make_refreshable_execute()

        return self._derive(lambda creds: getattr(self.call_stack(creds), name))

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Record the call; it is performed when the chain is replayed.
        return self._derive(lambda creds: self.call_stack(creds)(*args, **kwargs))

    def make_refreshable_execute(self) -> Callable:
        """
        Return an `execute` function that replays the recorded chain.

        The returned function retries exactly once after a RefreshError, using
        credentials from `creds_getter`. Any error from the retry propagates.
        """

        def execute(*args: Any, **kwargs: Any) -> Any:
            try:
                return self.call_stack(self.creds).execute(*args, **kwargs)
            except RefreshError as e:
                logger.warning(
                    f"RefreshError, going to attempt a creds refresh and retry: {e}"
                )
                # Refresh the access token; the setter writes into the shared
                # slot, so related wrappers stop using the expired credentials.
                self.creds = self.creds_getter()
                return self.call_stack(self.creds).execute(*args, **kwargs)

        return execute
|
||||
|
||||
|
||||
def _get_google_service(
|
||||
service_name: str,
|
||||
service_version: str,
|
||||
|
Reference in New Issue
Block a user