From 17bc0f89ff73231b81370cc1d45c6b951f6b3e05 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sat, 13 May 2023 23:04:16 -0700 Subject: [PATCH] DAN-56 Make google drive connector production ready (#45) --- backend/danswer/configs/constants.py | 1 + .../danswer/connectors/google_drive/batch.py | 29 +---- .../connectors/google_drive/connector_auth.py | 109 ++++++++++++++++++ backend/danswer/connectors/slack/config.py | 1 - backend/danswer/server/admin.py | 53 +++++---- backend/danswer/server/models.py | 34 +++++- backend/danswer/server/search_backend.py | 4 +- backend/scripts/ingestion.py | 2 + 8 files changed, 185 insertions(+), 48 deletions(-) create mode 100644 backend/danswer/connectors/google_drive/connector_auth.py diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 61882be5936c..0822ddb388a0 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -11,6 +11,7 @@ SEMANTIC_IDENTIFIER = "semantic_identifier" SECTION_CONTINUATION = "section_continuation" ALLOWED_USERS = "allowed_users" ALLOWED_GROUPS = "allowed_groups" +NO_AUTH_USER = "FooBarUser" # TODO rework this temporary solution class DocumentSource(str, Enum): diff --git a/backend/danswer/connectors/google_drive/batch.py b/backend/danswer/connectors/google_drive/batch.py index ce373ab53d77..e9875f9b34dc 100644 --- a/backend/danswer/connectors/google_drive/batch.py +++ b/backend/danswer/connectors/google_drive/batch.py @@ -1,12 +1,10 @@ import io -import os from collections.abc import Generator -from danswer.configs.app_configs import GOOGLE_DRIVE_CREDENTIAL_JSON from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED -from danswer.configs.app_configs import GOOGLE_DRIVE_TOKENS_JSON from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.constants import DocumentSource +from danswer.connectors.google_drive.connector_auth import get_drive_tokens from danswer.connectors.interfaces import PullLoader from danswer.connectors.models import Document from danswer.connectors.models import Section @@ -30,26 +28,6 @@ LINK_KEY = "link" TYPE_KEY = "type" -def get_credentials() -> Credentials: - creds = None - if os.path.exists(GOOGLE_DRIVE_TOKENS_JSON): - creds = Credentials.from_authorized_user_file(GOOGLE_DRIVE_TOKENS_JSON, SCOPES) - - if not creds or not creds.valid: - if creds and creds.expired and creds.refresh_token: - creds.refresh(Request()) - else: - flow = InstalledAppFlow.from_client_secrets_file( - GOOGLE_DRIVE_CREDENTIAL_JSON, SCOPES - ) - creds = flow.run_local_server() - - with open(GOOGLE_DRIVE_TOKENS_JSON, "w") as token_file: - token_file.write(creds.to_json()) - - return creds - - def get_file_batches( service: discovery.Resource, include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED, @@ -115,7 +93,10 @@ class BatchGoogleDriveLoader(PullLoader): ) -> None: self.batch_size = batch_size self.include_shared = include_shared - self.creds = get_credentials() + self.creds = get_drive_tokens() + + if not self.creds: + raise PermissionError("Unable to access Google Drive.") def load(self) -> Generator[list[Document], None, None]: service = discovery.build("drive", "v3", credentials=self.creds) diff --git a/backend/danswer/connectors/google_drive/connector_auth.py b/backend/danswer/connectors/google_drive/connector_auth.py new file mode 100644 index 000000000000..bda6a3ffc409 --- /dev/null +++ b/backend/danswer/connectors/google_drive/connector_auth.py @@ -0,0 +1,109 @@ +import os +from typing import Any +from urllib.parse import parse_qs +from urllib.parse import urlparse + +from danswer.configs.app_configs import GOOGLE_DRIVE_CREDENTIAL_JSON +from danswer.configs.app_configs import GOOGLE_DRIVE_TOKENS_JSON +from danswer.configs.app_configs import WEB_DOMAIN +from danswer.dynamic_configs import get_dynamic_config_store +from danswer.utils.logging import setup_logger +from google.auth.transport.requests import Request # type: ignore +from google.oauth2.credentials import Credentials # type: ignore +from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore +from googleapiclient import discovery # type: ignore + +logger = setup_logger() + +SCOPES = ["https://www.googleapis.com/auth/drive.readonly"] +FRONTEND_GOOGLE_DRIVE_REDIRECT = f"{WEB_DOMAIN}/auth/connectors/google_drive/callback" + + +def backend_get_credentials() -> Credentials: + """This approach does not work for the one-box builds""" + creds = None + if os.path.exists(GOOGLE_DRIVE_TOKENS_JSON): + creds = Credentials.from_authorized_user_file(GOOGLE_DRIVE_TOKENS_JSON, SCOPES) + + if not creds or not creds.valid: + if creds and creds.expired and creds.refresh_token: + creds.refresh(Request()) + else: + flow = InstalledAppFlow.from_client_secrets_file( + GOOGLE_DRIVE_CREDENTIAL_JSON, SCOPES + ) + creds = flow.run_local_server() + + with open(GOOGLE_DRIVE_TOKENS_JSON, "w") as token_file: + token_file.write(creds.to_json()) + + return creds + + +def get_drive_tokens(token_path: str = GOOGLE_DRIVE_TOKENS_JSON) -> Any: + if not os.path.exists(token_path): + return None + + creds = Credentials.from_authorized_user_file(token_path, SCOPES) + + if not creds: + return None + if creds.valid: + return creds + + if creds.expired and creds.refresh_token: + try: + creds.refresh(Request()) + if creds.valid: + with open(token_path, "w") as token_file: + token_file.write(creds.to_json()) + return creds + except Exception as e: + logger.exception(f"Failed to refresh google drive access token due to: {e}") + return None + return None + + +def verify_csrf(user_id: str, state: str) -> None: + csrf = get_dynamic_config_store().load(user_id) + if csrf != state: + raise PermissionError( + "State from Google Drive Connector callback does not match expected" + ) + + +def get_auth_url( + user_id: str, credentials_file: str = GOOGLE_DRIVE_CREDENTIAL_JSON +) -> str: + flow = InstalledAppFlow.from_client_secrets_file( + credentials_file, + scopes=SCOPES, + redirect_uri=FRONTEND_GOOGLE_DRIVE_REDIRECT, + ) + auth_url, _ = flow.authorization_url(prompt="consent") + + parsed_url = urlparse(auth_url) + params = parse_qs(parsed_url.query) + get_dynamic_config_store().store(user_id, params.get("state", [None])[0]) # type: ignore + return str(auth_url) + + +def save_access_tokens( + auth_code: str, + token_path: str = GOOGLE_DRIVE_TOKENS_JSON, + credentials_file: str = GOOGLE_DRIVE_CREDENTIAL_JSON, +) -> Any: + flow = InstalledAppFlow.from_client_secrets_file( + credentials_file, scopes=SCOPES, redirect_uri=FRONTEND_GOOGLE_DRIVE_REDIRECT + ) + flow.fetch_token(code=auth_code) + creds = flow.credentials + + os.makedirs(os.path.dirname(token_path), exist_ok=True) + with open(token_path, "w") as token_file: + token_file.write(creds.to_json()) + + if not get_drive_tokens(token_path): + raise PermissionError("Not able to access Google Drive.") + + return creds diff --git a/backend/danswer/connectors/slack/config.py b/backend/danswer/connectors/slack/config.py index 15e8443874c0..d54d6868d489 100644 --- a/backend/danswer/connectors/slack/config.py +++ b/backend/danswer/connectors/slack/config.py @@ -1,5 +1,4 @@ from danswer.dynamic_configs import get_dynamic_config_store -from danswer.dynamic_configs.interface import ConfigNotFoundError from pydantic import BaseModel diff --git a/backend/danswer/server/admin.py b/backend/danswer/server/admin.py index a98317cb5043..eddcef82f78c 100644 --- a/backend/danswer/server/admin.py +++ b/backend/danswer/server/admin.py @@ -1,7 +1,10 @@ -from datetime import datetime - from danswer.auth.users import current_admin_user from danswer.configs.constants import DocumentSource +from danswer.configs.constants import NO_AUTH_USER +from danswer.connectors.google_drive.connector_auth import get_auth_url +from danswer.connectors.google_drive.connector_auth import get_drive_tokens +from danswer.connectors.google_drive.connector_auth import save_access_tokens +from danswer.connectors.google_drive.connector_auth import verify_csrf from danswer.connectors.models import InputType from danswer.connectors.slack.config import get_slack_config from danswer.connectors.slack.config import SlackConfig @@ -12,17 +15,43 @@ from danswer.db.models import IndexAttempt from danswer.db.models import IndexingStatus from danswer.db.models import User from danswer.dynamic_configs.interface import ConfigNotFoundError +from danswer.server.models import AuthStatus +from danswer.server.models import AuthUrl +from danswer.server.models import GDriveCallback +from danswer.server.models import IndexAttemptSnapshot +from danswer.server.models import ListWebsiteIndexAttemptsResponse +from danswer.server.models import WebIndexAttemptRequest from danswer.utils.logging import setup_logger from fastapi import APIRouter from fastapi import Depends -from pydantic import BaseModel - router = APIRouter(prefix="/admin") logger = setup_logger() +@router.get("/connectors/google-drive/check-auth", response_model=AuthStatus) +def check_drive_tokens(_: User = Depends(current_admin_user)) -> AuthStatus: + tokens = get_drive_tokens() + authenticated = tokens is not None + return AuthStatus(authenticated=authenticated) + + +@router.get("/connectors/google-drive/authorize", response_model=AuthUrl) +def google_drive_auth(user: User = Depends(current_admin_user)) -> AuthUrl: + user_id = str(user.id) if user else NO_AUTH_USER + return AuthUrl(auth_url=get_auth_url(user_id)) + + +@router.get("/connectors/google-drive/callback", status_code=201) +def google_drive_callback( + callback: GDriveCallback = Depends(), user: User = Depends(current_admin_user) +) -> None: + user_id = str(user.id) if user else NO_AUTH_USER + verify_csrf(user_id, callback.state) + return save_access_tokens(callback.code) + + @router.get("/connectors/slack/config", response_model=SlackConfig) def fetch_slack_config(_: User = Depends(current_admin_user)) -> SlackConfig: try: @@ -38,10 +67,6 @@ def modify_slack_config( update_slack_config(slack_config) -class WebIndexAttemptRequest(BaseModel): - url: str - - @router.post("/connectors/web/index-attempt", status_code=201) def index_website( web_index_attempt_request: WebIndexAttemptRequest, @@ -56,18 +81,6 @@ def index_website( insert_index_attempt(index_request) -class IndexAttemptSnapshot(BaseModel): - url: str - status: IndexingStatus - time_created: datetime - time_updated: datetime - docs_indexed: int - - -class ListWebsiteIndexAttemptsResponse(BaseModel): - index_attempts: list[IndexAttemptSnapshot] - - @router.get("/connectors/web/index-attempt") def list_website_index_attempts( _: User = Depends(current_admin_user), diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index ceebdfefc0fd..530ed995ceb7 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -1,13 +1,29 @@ +from datetime import datetime + from danswer.datastores.interfaces import DatastoreFilter +from danswer.db.models import IndexingStatus from pydantic import BaseModel +class AuthStatus(BaseModel): + authenticated: bool + + +class AuthUrl(BaseModel): + auth_url: str + + +class GDriveCallback(BaseModel): + state: str + code: str + + class UserRoleResponse(BaseModel): role: str class SearchDoc(BaseModel): - semantic_name: str + semantic_identifier: str link: str blurb: str source_type: str @@ -31,3 +47,19 @@ class KeywordResponse(BaseModel): class UserByEmail(BaseModel): user_email: str + + +class WebIndexAttemptRequest(BaseModel): + url: str + + +class IndexAttemptSnapshot(BaseModel): + url: str + status: IndexingStatus + time_created: datetime + time_updated: datetime + docs_indexed: int + + +class ListWebsiteIndexAttemptsResponse(BaseModel): + index_attempts: list[IndexAttemptSnapshot] diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py index f017f2f43726..c4aaf3a7c13c 100644 --- a/backend/danswer/server/search_backend.py +++ b/backend/danswer/server/search_backend.py @@ -77,7 +77,7 @@ def direct_qa( top_docs = [ SearchDoc( - semantic_name=chunk.semantic_identifier, + semantic_identifier=chunk.semantic_identifier, link=chunk.source_links.get("0") if chunk.source_links else None, blurb=chunk.blurb, source_type=chunk.source_type, @@ -116,7 +116,7 @@ def stream_direct_qa( top_docs = [ SearchDoc( - semantic_name=chunk.semantic_identifier, + semantic_identifier=chunk.semantic_identifier, link=chunk.source_links.get("0") if chunk.source_links else None, blurb=chunk.blurb, source_type=chunk.source_type, diff --git a/backend/scripts/ingestion.py b/backend/scripts/ingestion.py index 231fe941388b..6de7ce830727 100644 --- a/backend/scripts/ingestion.py +++ b/backend/scripts/ingestion.py @@ -7,6 +7,7 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION from danswer.connectors.github.batch import BatchGithubLoader from danswer.connectors.google_drive.batch import BatchGoogleDriveLoader +from danswer.connectors.google_drive.connector_auth import backend_get_credentials from danswer.connectors.interfaces import PullLoader from danswer.connectors.slack.batch import BatchSlackLoader from danswer.connectors.web.pull import WebLoader @@ -71,6 +72,7 @@ def load_web_batch(url: str, qdrant_collection: str) -> None: def load_google_drive_batch(qdrant_collection: str) -> None: logger.info("Loading documents from Google Drive.") + backend_get_credentials() load_batch( BatchGoogleDriveLoader(batch_size=INDEX_BATCH_SIZE), DefaultChunker(),