mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 12:03:54 +02:00
DAN-56 Make google drive connector production ready (#45)
This commit is contained in:
@@ -11,6 +11,7 @@ SEMANTIC_IDENTIFIER = "semantic_identifier"
|
||||
SECTION_CONTINUATION = "section_continuation"
|
||||
ALLOWED_USERS = "allowed_users"
|
||||
ALLOWED_GROUPS = "allowed_groups"
|
||||
NO_AUTH_USER = "FooBarUser" # TODO rework this temporary solution
|
||||
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
|
@@ -1,12 +1,10 @@
|
||||
import io
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
from danswer.configs.app_configs import GOOGLE_DRIVE_CREDENTIAL_JSON
|
||||
from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
|
||||
from danswer.configs.app_configs import GOOGLE_DRIVE_TOKENS_JSON
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.google_drive.connector_auth import get_drive_tokens
|
||||
from danswer.connectors.interfaces import PullLoader
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
@@ -30,26 +28,6 @@ LINK_KEY = "link"
|
||||
TYPE_KEY = "type"
|
||||
|
||||
|
||||
def get_credentials() -> Credentials:
    """Return Google Drive credentials, refreshing or re-authorizing if needed.

    Credentials are read from ``GOOGLE_DRIVE_TOKENS_JSON`` when that file
    exists. If they are missing or not valid they are either refreshed (when
    expired and a refresh token is present) or obtained through a local-server
    OAuth flow built from ``GOOGLE_DRIVE_CREDENTIAL_JSON``; the resulting
    credentials are then written back to the tokens file.
    """
    tokens_file = GOOGLE_DRIVE_TOKENS_JSON
    creds = (
        Credentials.from_authorized_user_file(tokens_file, SCOPES)
        if os.path.exists(tokens_file)
        else None
    )

    # Nothing to do when the cached credentials are already usable.
    if creds and creds.valid:
        return creds

    if creds and creds.expired and creds.refresh_token:
        creds.refresh(Request())
    else:
        flow = InstalledAppFlow.from_client_secrets_file(
            GOOGLE_DRIVE_CREDENTIAL_JSON, SCOPES
        )
        creds = flow.run_local_server()

    # Persist the refreshed / newly granted credentials for the next run.
    with open(tokens_file, "w") as token_file:
        token_file.write(creds.to_json())

    return creds
|
||||
|
||||
|
||||
def get_file_batches(
|
||||
service: discovery.Resource,
|
||||
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
|
||||
@@ -115,7 +93,10 @@ class BatchGoogleDriveLoader(PullLoader):
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.include_shared = include_shared
|
||||
self.creds = get_credentials()
|
||||
self.creds = get_drive_tokens()
|
||||
|
||||
if not self.creds:
|
||||
raise PermissionError("Unable to access Google Drive.")
|
||||
|
||||
def load(self) -> Generator[list[Document], None, None]:
|
||||
service = discovery.build("drive", "v3", credentials=self.creds)
|
||||
|
109
backend/danswer/connectors/google_drive/connector_auth.py
Normal file
109
backend/danswer/connectors/google_drive/connector_auth.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import os
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from danswer.configs.app_configs import GOOGLE_DRIVE_CREDENTIAL_JSON
|
||||
from danswer.configs.app_configs import GOOGLE_DRIVE_TOKENS_JSON
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.dynamic_configs import get_dynamic_config_store
|
||||
from danswer.utils.logging import setup_logger
|
||||
from google.auth.transport.requests import Request # type: ignore
|
||||
from google.oauth2.credentials import Credentials # type: ignore
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
from googleapiclient import discovery # type: ignore
|
||||
|
||||
logger = setup_logger()

# OAuth scopes requested for Google Drive (read-only access).
SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]

# Redirect URI passed to the OAuth flow when building the auth URL and when
# exchanging the auth code for tokens.
FRONTEND_GOOGLE_DRIVE_REDIRECT = f"{WEB_DOMAIN}/auth/connectors/google_drive/callback"
|
||||
|
||||
|
||||
def backend_get_credentials() -> Credentials:
    """Obtain Google Drive credentials via the local-server OAuth flow.

    This approach does not work for the one-box builds.

    Reads cached credentials from ``GOOGLE_DRIVE_TOKENS_JSON`` if the file
    exists. Invalid or missing credentials are refreshed when possible,
    otherwise a new grant is requested with ``flow.run_local_server()``.
    Whenever credentials were refreshed or newly granted they are written
    back to the tokens file.
    """
    creds = None
    if os.path.exists(GOOGLE_DRIVE_TOKENS_JSON):
        creds = Credentials.from_authorized_user_file(GOOGLE_DRIVE_TOKENS_JSON, SCOPES)

    needs_update = not creds or not creds.valid
    if needs_update:
        refreshable = creds and creds.expired and creds.refresh_token
        if refreshable:
            creds.refresh(Request())
        else:
            creds = InstalledAppFlow.from_client_secrets_file(
                GOOGLE_DRIVE_CREDENTIAL_JSON, SCOPES
            ).run_local_server()

        serialized = creds.to_json()
        with open(GOOGLE_DRIVE_TOKENS_JSON, "w") as token_file:
            token_file.write(serialized)

    return creds
|
||||
|
||||
|
||||
def get_drive_tokens(token_path: str = GOOGLE_DRIVE_TOKENS_JSON) -> Any:
    """Load Google Drive credentials from ``token_path`` without user interaction.

    Returns the credentials object when it is valid, or when it could be
    refreshed into a valid state (in which case the refreshed credentials are
    written back to ``token_path``). Returns ``None`` when the file does not
    exist, the credentials cannot be refreshed, the refresh raises, or the
    credentials are still invalid after refreshing.
    """
    if not os.path.exists(token_path):
        return None

    creds = Credentials.from_authorized_user_file(token_path, SCOPES)
    if not creds:
        return None
    if creds.valid:
        return creds

    # Only expired credentials that carry a refresh token can be recovered.
    if not (creds.expired and creds.refresh_token):
        return None

    try:
        creds.refresh(Request())
        if not creds.valid:
            return None
        with open(token_path, "w") as token_file:
            token_file.write(creds.to_json())
        return creds
    except Exception as e:
        # Refresh failures are logged and reported to the caller as "no tokens".
        logger.exception(f"Failed to refresh google drive access token due to: {e}")
        return None
|
||||
|
||||
|
||||
def verify_csrf(user_id: str, state: str) -> None:
    """Check the OAuth callback ``state`` against the value stored for the user.

    The expected value is loaded from the dynamic config store under
    ``user_id`` (it is stored there by ``get_auth_url``).

    Raises:
        PermissionError: if no string value is stored for the user or the
            stored value does not match ``state``.
    """
    # Local import keeps this block self-contained without touching the
    # module's import section.
    import hmac

    csrf = get_dynamic_config_store().load(user_id)

    # ``state`` comes from the callback request, so compare in constant time
    # and treat a missing / non-string stored value as a mismatch.
    matches = (
        isinstance(csrf, str)
        and isinstance(state, str)
        and hmac.compare_digest(csrf.encode("utf-8"), state.encode("utf-8"))
    )
    if not matches:
        raise PermissionError(
            "State from Google Drive Connector callback does not match expected"
        )
|
||||
|
||||
|
||||
def get_auth_url(
    user_id: str, credentials_file: str = GOOGLE_DRIVE_CREDENTIAL_JSON
) -> str:
    """Build the Google OAuth consent URL and remember its ``state`` for the user.

    Args:
        user_id: key under which the OAuth ``state`` is stored in the dynamic
            config store, so the callback can be verified later.
        credentials_file: path to the OAuth client secrets file.

    Returns:
        The authorization URL the user should be sent to.
    """
    flow = InstalledAppFlow.from_client_secrets_file(
        credentials_file,
        scopes=SCOPES,
        redirect_uri=FRONTEND_GOOGLE_DRIVE_REDIRECT,
    )
    # authorization_url() returns the URL together with the state embedded in
    # it, so there is no need to re-parse the query string to recover it.
    auth_url, state = flow.authorization_url(prompt="consent")

    get_dynamic_config_store().store(user_id, state)
    return str(auth_url)
|
||||
|
||||
|
||||
def save_access_tokens(
    auth_code: str,
    token_path: str = GOOGLE_DRIVE_TOKENS_JSON,
    credentials_file: str = GOOGLE_DRIVE_CREDENTIAL_JSON,
) -> Any:
    """Exchange an OAuth auth code for tokens and persist them to ``token_path``.

    Args:
        auth_code: authorization code received on the OAuth callback.
        token_path: file the credentials JSON is written to.
        credentials_file: path to the OAuth client secrets file.

    Returns:
        The credentials object produced by the flow.

    Raises:
        PermissionError: if the saved tokens cannot be loaded back as usable
            credentials via ``get_drive_tokens``.
    """
    flow = InstalledAppFlow.from_client_secrets_file(
        credentials_file, scopes=SCOPES, redirect_uri=FRONTEND_GOOGLE_DRIVE_REDIRECT
    )
    flow.fetch_token(code=auth_code)
    creds = flow.credentials

    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when token_path actually has a directory component.
    token_dir = os.path.dirname(token_path)
    if token_dir:
        os.makedirs(token_dir, exist_ok=True)
    with open(token_path, "w") as token_file:
        token_file.write(creds.to_json())

    # Round-trip check: the file just written must yield usable credentials.
    if not get_drive_tokens(token_path):
        raise PermissionError("Not able to access Google Drive.")

    return creds
|
@@ -1,5 +1,4 @@
|
||||
from danswer.dynamic_configs import get_dynamic_config_store
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
@@ -1,7 +1,10 @@
|
||||
from datetime import datetime
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import NO_AUTH_USER
|
||||
from danswer.connectors.google_drive.connector_auth import get_auth_url
|
||||
from danswer.connectors.google_drive.connector_auth import get_drive_tokens
|
||||
from danswer.connectors.google_drive.connector_auth import save_access_tokens
|
||||
from danswer.connectors.google_drive.connector_auth import verify_csrf
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.connectors.slack.config import get_slack_config
|
||||
from danswer.connectors.slack.config import SlackConfig
|
||||
@@ -12,17 +15,43 @@ from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import User
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.server.models import AuthStatus
|
||||
from danswer.server.models import AuthUrl
|
||||
from danswer.server.models import GDriveCallback
|
||||
from danswer.server.models import IndexAttemptSnapshot
|
||||
from danswer.server.models import ListWebsiteIndexAttemptsResponse
|
||||
from danswer.server.models import WebIndexAttemptRequest
|
||||
from danswer.utils.logging import setup_logger
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
router = APIRouter(prefix="/admin")
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@router.get("/connectors/google-drive/check-auth", response_model=AuthStatus)
def check_drive_tokens(_: User = Depends(current_admin_user)) -> AuthStatus:
    """Report whether usable Google Drive tokens are currently available."""
    # get_drive_tokens() returns None when no usable credentials exist.
    has_tokens = get_drive_tokens() is not None
    return AuthStatus(authenticated=has_tokens)
|
||||
|
||||
|
||||
@router.get("/connectors/google-drive/authorize", response_model=AuthUrl)
def google_drive_auth(user: User = Depends(current_admin_user)) -> AuthUrl:
    """Return the Google OAuth consent URL for the requesting admin."""
    # Fall back to the placeholder id when the dependency yields no user.
    if user:
        user_id = str(user.id)
    else:
        user_id = NO_AUTH_USER

    auth_url = get_auth_url(user_id)
    return AuthUrl(auth_url=auth_url)
|
||||
|
||||
|
||||
@router.get("/connectors/google-drive/callback", status_code=201)
def google_drive_callback(
    callback: GDriveCallback = Depends(), user: User = Depends(current_admin_user)
) -> None:
    """Handle the Google OAuth callback: verify ``state`` and store the tokens.

    Raises:
        PermissionError: propagated from ``verify_csrf`` when the state does
            not match, or from ``save_access_tokens`` when the saved tokens
            are not usable.
    """
    # Fall back to the placeholder id when the dependency yields no user.
    user_id = str(user.id) if user else NO_AUTH_USER
    verify_csrf(user_id, callback.state)

    # The endpoint is declared to return nothing; do not hand the credentials
    # object (which holds the OAuth tokens) back as the response body.
    save_access_tokens(callback.code)
|
||||
|
||||
|
||||
@router.get("/connectors/slack/config", response_model=SlackConfig)
|
||||
def fetch_slack_config(_: User = Depends(current_admin_user)) -> SlackConfig:
|
||||
try:
|
||||
@@ -38,10 +67,6 @@ def modify_slack_config(
|
||||
update_slack_config(slack_config)
|
||||
|
||||
|
||||
class WebIndexAttemptRequest(BaseModel):
    # Request body of the web index-attempt endpoint.
    url: str
|
||||
|
||||
|
||||
@router.post("/connectors/web/index-attempt", status_code=201)
|
||||
def index_website(
|
||||
web_index_attempt_request: WebIndexAttemptRequest,
|
||||
@@ -56,18 +81,6 @@ def index_website(
|
||||
insert_index_attempt(index_request)
|
||||
|
||||
|
||||
class IndexAttemptSnapshot(BaseModel):
    # One entry of ListWebsiteIndexAttemptsResponse.index_attempts.
    url: str
    status: IndexingStatus
    # NOTE(review): presumably creation / last-update times of the attempt.
    time_created: datetime
    time_updated: datetime
    docs_indexed: int
|
||||
|
||||
|
||||
class ListWebsiteIndexAttemptsResponse(BaseModel):
    # Wrapper holding a list of index-attempt snapshots.
    index_attempts: list[IndexAttemptSnapshot]
|
||||
|
||||
|
||||
@router.get("/connectors/web/index-attempt")
|
||||
def list_website_index_attempts(
|
||||
_: User = Depends(current_admin_user),
|
||||
|
@@ -1,13 +1,29 @@
|
||||
from datetime import datetime
|
||||
|
||||
from danswer.datastores.interfaces import DatastoreFilter
|
||||
from danswer.db.models import IndexingStatus
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AuthStatus(BaseModel):
    # Response model of the google-drive check-auth endpoint.
    authenticated: bool
|
||||
|
||||
|
||||
class AuthUrl(BaseModel):
    # Response model of the google-drive authorize endpoint.
    auth_url: str
|
||||
|
||||
|
||||
class GDriveCallback(BaseModel):
    # Parameters received on the google-drive OAuth callback.
    # `state` is checked with verify_csrf; `code` is exchanged for tokens.
    state: str
    code: str
|
||||
|
||||
|
||||
class UserRoleResponse(BaseModel):
    # NOTE(review): presumably the role name of the requesting user; confirm
    # against the endpoint that returns this model.
    role: str
|
||||
|
||||
|
||||
class SearchDoc(BaseModel):
|
||||
semantic_name: str
|
||||
semantic_identifier: str
|
||||
link: str
|
||||
blurb: str
|
||||
source_type: str
|
||||
@@ -31,3 +47,19 @@ class KeywordResponse(BaseModel):
|
||||
|
||||
class UserByEmail(BaseModel):
    # Payload carrying a single user's email address.
    user_email: str
|
||||
|
||||
|
||||
class WebIndexAttemptRequest(BaseModel):
    # Request body of the web index-attempt endpoint.
    url: str
|
||||
|
||||
|
||||
class IndexAttemptSnapshot(BaseModel):
    # One entry of ListWebsiteIndexAttemptsResponse.index_attempts.
    url: str
    status: IndexingStatus
    # NOTE(review): presumably creation / last-update times of the attempt.
    time_created: datetime
    time_updated: datetime
    docs_indexed: int
|
||||
|
||||
|
||||
class ListWebsiteIndexAttemptsResponse(BaseModel):
    # Wrapper holding a list of index-attempt snapshots.
    index_attempts: list[IndexAttemptSnapshot]
|
||||
|
@@ -77,7 +77,7 @@ def direct_qa(
|
||||
|
||||
top_docs = [
|
||||
SearchDoc(
|
||||
semantic_name=chunk.semantic_identifier,
|
||||
semantic_identifier=chunk.semantic_identifier,
|
||||
link=chunk.source_links.get("0") if chunk.source_links else None,
|
||||
blurb=chunk.blurb,
|
||||
source_type=chunk.source_type,
|
||||
@@ -116,7 +116,7 @@ def stream_direct_qa(
|
||||
|
||||
top_docs = [
|
||||
SearchDoc(
|
||||
semantic_name=chunk.semantic_identifier,
|
||||
semantic_identifier=chunk.semantic_identifier,
|
||||
link=chunk.source_links.get("0") if chunk.source_links else None,
|
||||
blurb=chunk.blurb,
|
||||
source_type=chunk.source_type,
|
||||
|
@@ -7,6 +7,7 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
|
||||
from danswer.connectors.github.batch import BatchGithubLoader
|
||||
from danswer.connectors.google_drive.batch import BatchGoogleDriveLoader
|
||||
from danswer.connectors.google_drive.connector_auth import backend_get_credentials
|
||||
from danswer.connectors.interfaces import PullLoader
|
||||
from danswer.connectors.slack.batch import BatchSlackLoader
|
||||
from danswer.connectors.web.pull import WebLoader
|
||||
@@ -71,6 +72,7 @@ def load_web_batch(url: str, qdrant_collection: str) -> None:
|
||||
|
||||
def load_google_drive_batch(qdrant_collection: str) -> None:
|
||||
logger.info("Loading documents from Google Drive.")
|
||||
backend_get_credentials()
|
||||
load_batch(
|
||||
BatchGoogleDriveLoader(batch_size=INDEX_BATCH_SIZE),
|
||||
DefaultChunker(),
|
||||
|
Reference in New Issue
Block a user