DAN-56 Make google drive connector production ready (#45)

This commit is contained in:
Yuhong Sun
2023-05-13 23:04:16 -07:00
committed by GitHub
parent b2cde3e4bb
commit 17bc0f89ff
8 changed files with 185 additions and 48 deletions

View File

@@ -11,6 +11,7 @@ SEMANTIC_IDENTIFIER = "semantic_identifier"
SECTION_CONTINUATION = "section_continuation"
ALLOWED_USERS = "allowed_users"
ALLOWED_GROUPS = "allowed_groups"
NO_AUTH_USER = "FooBarUser" # TODO rework this temporary solution
class DocumentSource(str, Enum):

View File

@@ -1,12 +1,10 @@
import io
import os
from collections.abc import Generator
from danswer.configs.app_configs import GOOGLE_DRIVE_CREDENTIAL_JSON
from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
from danswer.configs.app_configs import GOOGLE_DRIVE_TOKENS_JSON
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.google_drive.connector_auth import get_drive_tokens
from danswer.connectors.interfaces import PullLoader
from danswer.connectors.models import Document
from danswer.connectors.models import Section
@@ -30,26 +28,6 @@ LINK_KEY = "link"
TYPE_KEY = "type"
def get_credentials() -> Credentials:
creds = None
if os.path.exists(GOOGLE_DRIVE_TOKENS_JSON):
creds = Credentials.from_authorized_user_file(GOOGLE_DRIVE_TOKENS_JSON, SCOPES)
if not creds or not creds.valid:
if creds and creds.expired and creds.refresh_token:
creds.refresh(Request())
else:
flow = InstalledAppFlow.from_client_secrets_file(
GOOGLE_DRIVE_CREDENTIAL_JSON, SCOPES
)
creds = flow.run_local_server()
with open(GOOGLE_DRIVE_TOKENS_JSON, "w") as token_file:
token_file.write(creds.to_json())
return creds
def get_file_batches(
service: discovery.Resource,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
@@ -115,7 +93,10 @@ class BatchGoogleDriveLoader(PullLoader):
) -> None:
self.batch_size = batch_size
self.include_shared = include_shared
self.creds = get_credentials()
self.creds = get_drive_tokens()
if not self.creds:
raise PermissionError("Unable to access Google Drive.")
def load(self) -> Generator[list[Document], None, None]:
service = discovery.build("drive", "v3", credentials=self.creds)

View File

@@ -0,0 +1,109 @@
import os
from typing import Any
from urllib.parse import parse_qs
from urllib.parse import urlparse
from danswer.configs.app_configs import GOOGLE_DRIVE_CREDENTIAL_JSON
from danswer.configs.app_configs import GOOGLE_DRIVE_TOKENS_JSON
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.utils.logging import setup_logger
from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials # type: ignore
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from googleapiclient import discovery # type: ignore
logger = setup_logger()
SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
FRONTEND_GOOGLE_DRIVE_REDIRECT = f"{WEB_DOMAIN}/auth/connectors/google_drive/callback"
def backend_get_credentials() -> Credentials:
"""This approach does not work for the one-box builds"""
creds = None
if os.path.exists(GOOGLE_DRIVE_TOKENS_JSON):
creds = Credentials.from_authorized_user_file(GOOGLE_DRIVE_TOKENS_JSON, SCOPES)
if not creds or not creds.valid:
if creds and creds.expired and creds.refresh_token:
creds.refresh(Request())
else:
flow = InstalledAppFlow.from_client_secrets_file(
GOOGLE_DRIVE_CREDENTIAL_JSON, SCOPES
)
creds = flow.run_local_server()
with open(GOOGLE_DRIVE_TOKENS_JSON, "w") as token_file:
token_file.write(creds.to_json())
return creds
def get_drive_tokens(token_path: str = GOOGLE_DRIVE_TOKENS_JSON) -> Any:
if not os.path.exists(token_path):
return None
creds = Credentials.from_authorized_user_file(token_path, SCOPES)
if not creds:
return None
if creds.valid:
return creds
if creds.expired and creds.refresh_token:
try:
creds.refresh(Request())
if creds.valid:
with open(token_path, "w") as token_file:
token_file.write(creds.to_json())
return creds
except Exception as e:
logger.exception(f"Failed to refresh google drive access token due to: {e}")
return None
return None
def verify_csrf(user_id: str, state: str) -> None:
csrf = get_dynamic_config_store().load(user_id)
if csrf != state:
raise PermissionError(
"State from Google Drive Connector callback does not match expected"
)
def get_auth_url(
user_id: str, credentials_file: str = GOOGLE_DRIVE_CREDENTIAL_JSON
) -> str:
flow = InstalledAppFlow.from_client_secrets_file(
credentials_file,
scopes=SCOPES,
redirect_uri=FRONTEND_GOOGLE_DRIVE_REDIRECT,
)
auth_url, _ = flow.authorization_url(prompt="consent")
parsed_url = urlparse(auth_url)
params = parse_qs(parsed_url.query)
get_dynamic_config_store().store(user_id, params.get("state", [None])[0]) # type: ignore
return str(auth_url)
def save_access_tokens(
auth_code: str,
token_path: str = GOOGLE_DRIVE_TOKENS_JSON,
credentials_file: str = GOOGLE_DRIVE_CREDENTIAL_JSON,
) -> Any:
flow = InstalledAppFlow.from_client_secrets_file(
credentials_file, scopes=SCOPES, redirect_uri=FRONTEND_GOOGLE_DRIVE_REDIRECT
)
flow.fetch_token(code=auth_code)
creds = flow.credentials
os.makedirs(os.path.dirname(token_path), exist_ok=True)
with open(token_path, "w") as token_file:
token_file.write(creds.to_json())
if not get_drive_tokens(token_path):
raise PermissionError("Not able to access Google Drive.")
return creds

View File

@@ -1,5 +1,4 @@
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from pydantic import BaseModel

View File

@@ -1,7 +1,10 @@
from datetime import datetime
from danswer.auth.users import current_admin_user
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import NO_AUTH_USER
from danswer.connectors.google_drive.connector_auth import get_auth_url
from danswer.connectors.google_drive.connector_auth import get_drive_tokens
from danswer.connectors.google_drive.connector_auth import save_access_tokens
from danswer.connectors.google_drive.connector_auth import verify_csrf
from danswer.connectors.models import InputType
from danswer.connectors.slack.config import get_slack_config
from danswer.connectors.slack.config import SlackConfig
@@ -12,17 +15,43 @@ from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import User
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.server.models import AuthStatus
from danswer.server.models import AuthUrl
from danswer.server.models import GDriveCallback
from danswer.server.models import IndexAttemptSnapshot
from danswer.server.models import ListWebsiteIndexAttemptsResponse
from danswer.server.models import WebIndexAttemptRequest
from danswer.utils.logging import setup_logger
from fastapi import APIRouter
from fastapi import Depends
from pydantic import BaseModel
router = APIRouter(prefix="/admin")
logger = setup_logger()
@router.get("/connectors/google-drive/check-auth", response_model=AuthStatus)
def check_drive_tokens(_: User = Depends(current_admin_user)) -> AuthStatus:
tokens = get_drive_tokens()
authenticated = tokens is not None
return AuthStatus(authenticated=authenticated)
@router.get("/connectors/google-drive/authorize", response_model=AuthUrl)
def google_drive_auth(user: User = Depends(current_admin_user)) -> AuthUrl:
user_id = str(user.id) if user else NO_AUTH_USER
return AuthUrl(auth_url=get_auth_url(user_id))
@router.get("/connectors/google-drive/callback", status_code=201)
def google_drive_callback(
callback: GDriveCallback = Depends(), user: User = Depends(current_admin_user)
) -> None:
user_id = str(user.id) if user else NO_AUTH_USER
verify_csrf(user_id, callback.state)
return save_access_tokens(callback.code)
@router.get("/connectors/slack/config", response_model=SlackConfig)
def fetch_slack_config(_: User = Depends(current_admin_user)) -> SlackConfig:
try:
@@ -38,10 +67,6 @@ def modify_slack_config(
update_slack_config(slack_config)
class WebIndexAttemptRequest(BaseModel):
url: str
@router.post("/connectors/web/index-attempt", status_code=201)
def index_website(
web_index_attempt_request: WebIndexAttemptRequest,
@@ -56,18 +81,6 @@ def index_website(
insert_index_attempt(index_request)
class IndexAttemptSnapshot(BaseModel):
url: str
status: IndexingStatus
time_created: datetime
time_updated: datetime
docs_indexed: int
class ListWebsiteIndexAttemptsResponse(BaseModel):
index_attempts: list[IndexAttemptSnapshot]
@router.get("/connectors/web/index-attempt")
def list_website_index_attempts(
_: User = Depends(current_admin_user),

View File

@@ -1,13 +1,29 @@
from datetime import datetime
from danswer.datastores.interfaces import DatastoreFilter
from danswer.db.models import IndexingStatus
from pydantic import BaseModel
class AuthStatus(BaseModel):
authenticated: bool
class AuthUrl(BaseModel):
auth_url: str
class GDriveCallback(BaseModel):
state: str
code: str
class UserRoleResponse(BaseModel):
role: str
class SearchDoc(BaseModel):
semantic_name: str
semantic_identifier: str
link: str
blurb: str
source_type: str
@@ -31,3 +47,19 @@ class KeywordResponse(BaseModel):
class UserByEmail(BaseModel):
user_email: str
class WebIndexAttemptRequest(BaseModel):
url: str
class IndexAttemptSnapshot(BaseModel):
url: str
status: IndexingStatus
time_created: datetime
time_updated: datetime
docs_indexed: int
class ListWebsiteIndexAttemptsResponse(BaseModel):
index_attempts: list[IndexAttemptSnapshot]

View File

@@ -77,7 +77,7 @@ def direct_qa(
top_docs = [
SearchDoc(
semantic_name=chunk.semantic_identifier,
semantic_identifier=chunk.semantic_identifier,
link=chunk.source_links.get("0") if chunk.source_links else None,
blurb=chunk.blurb,
source_type=chunk.source_type,
@@ -116,7 +116,7 @@ def stream_direct_qa(
top_docs = [
SearchDoc(
semantic_name=chunk.semantic_identifier,
semantic_identifier=chunk.semantic_identifier,
link=chunk.source_links.get("0") if chunk.source_links else None,
blurb=chunk.blurb,
source_type=chunk.source_type,

View File

@@ -7,6 +7,7 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.connectors.github.batch import BatchGithubLoader
from danswer.connectors.google_drive.batch import BatchGoogleDriveLoader
from danswer.connectors.google_drive.connector_auth import backend_get_credentials
from danswer.connectors.interfaces import PullLoader
from danswer.connectors.slack.batch import BatchSlackLoader
from danswer.connectors.web.pull import WebLoader
@@ -71,6 +72,7 @@ def load_web_batch(url: str, qdrant_collection: str) -> None:
def load_google_drive_batch(qdrant_collection: str) -> None:
logger.info("Loading documents from Google Drive.")
backend_get_credentials()
load_batch(
BatchGoogleDriveLoader(batch_size=INDEX_BATCH_SIZE),
DefaultChunker(),