mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-20 13:05:49 +02:00
DAN-56 Make google drive connector production ready (#45)
This commit is contained in:
@@ -11,6 +11,7 @@ SEMANTIC_IDENTIFIER = "semantic_identifier"
|
|||||||
SECTION_CONTINUATION = "section_continuation"
|
SECTION_CONTINUATION = "section_continuation"
|
||||||
ALLOWED_USERS = "allowed_users"
|
ALLOWED_USERS = "allowed_users"
|
||||||
ALLOWED_GROUPS = "allowed_groups"
|
ALLOWED_GROUPS = "allowed_groups"
|
||||||
|
NO_AUTH_USER = "FooBarUser" # TODO rework this temporary solution
|
||||||
|
|
||||||
|
|
||||||
class DocumentSource(str, Enum):
|
class DocumentSource(str, Enum):
|
||||||
|
@@ -1,12 +1,10 @@
|
|||||||
import io
|
import io
|
||||||
import os
|
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
from danswer.configs.app_configs import GOOGLE_DRIVE_CREDENTIAL_JSON
|
|
||||||
from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
|
from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
|
||||||
from danswer.configs.app_configs import GOOGLE_DRIVE_TOKENS_JSON
|
|
||||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||||
from danswer.configs.constants import DocumentSource
|
from danswer.configs.constants import DocumentSource
|
||||||
|
from danswer.connectors.google_drive.connector_auth import get_drive_tokens
|
||||||
from danswer.connectors.interfaces import PullLoader
|
from danswer.connectors.interfaces import PullLoader
|
||||||
from danswer.connectors.models import Document
|
from danswer.connectors.models import Document
|
||||||
from danswer.connectors.models import Section
|
from danswer.connectors.models import Section
|
||||||
@@ -30,26 +28,6 @@ LINK_KEY = "link"
|
|||||||
TYPE_KEY = "type"
|
TYPE_KEY = "type"
|
||||||
|
|
||||||
|
|
||||||
def get_credentials() -> Credentials:
|
|
||||||
creds = None
|
|
||||||
if os.path.exists(GOOGLE_DRIVE_TOKENS_JSON):
|
|
||||||
creds = Credentials.from_authorized_user_file(GOOGLE_DRIVE_TOKENS_JSON, SCOPES)
|
|
||||||
|
|
||||||
if not creds or not creds.valid:
|
|
||||||
if creds and creds.expired and creds.refresh_token:
|
|
||||||
creds.refresh(Request())
|
|
||||||
else:
|
|
||||||
flow = InstalledAppFlow.from_client_secrets_file(
|
|
||||||
GOOGLE_DRIVE_CREDENTIAL_JSON, SCOPES
|
|
||||||
)
|
|
||||||
creds = flow.run_local_server()
|
|
||||||
|
|
||||||
with open(GOOGLE_DRIVE_TOKENS_JSON, "w") as token_file:
|
|
||||||
token_file.write(creds.to_json())
|
|
||||||
|
|
||||||
return creds
|
|
||||||
|
|
||||||
|
|
||||||
def get_file_batches(
|
def get_file_batches(
|
||||||
service: discovery.Resource,
|
service: discovery.Resource,
|
||||||
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
|
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
|
||||||
@@ -115,7 +93,10 @@ class BatchGoogleDriveLoader(PullLoader):
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.include_shared = include_shared
|
self.include_shared = include_shared
|
||||||
self.creds = get_credentials()
|
self.creds = get_drive_tokens()
|
||||||
|
|
||||||
|
if not self.creds:
|
||||||
|
raise PermissionError("Unable to access Google Drive.")
|
||||||
|
|
||||||
def load(self) -> Generator[list[Document], None, None]:
|
def load(self) -> Generator[list[Document], None, None]:
|
||||||
service = discovery.build("drive", "v3", credentials=self.creds)
|
service = discovery.build("drive", "v3", credentials=self.creds)
|
||||||
|
109
backend/danswer/connectors/google_drive/connector_auth.py
Normal file
109
backend/danswer/connectors/google_drive/connector_auth.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import parse_qs
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from danswer.configs.app_configs import GOOGLE_DRIVE_CREDENTIAL_JSON
|
||||||
|
from danswer.configs.app_configs import GOOGLE_DRIVE_TOKENS_JSON
|
||||||
|
from danswer.configs.app_configs import WEB_DOMAIN
|
||||||
|
from danswer.dynamic_configs import get_dynamic_config_store
|
||||||
|
from danswer.utils.logging import setup_logger
|
||||||
|
from google.auth.transport.requests import Request # type: ignore
|
||||||
|
from google.oauth2.credentials import Credentials # type: ignore
|
||||||
|
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||||
|
from googleapiclient import discovery # type: ignore
|
||||||
|
|
||||||
|
logger = setup_logger()
|
||||||
|
|
||||||
|
SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
|
||||||
|
FRONTEND_GOOGLE_DRIVE_REDIRECT = f"{WEB_DOMAIN}/auth/connectors/google_drive/callback"
|
||||||
|
|
||||||
|
|
||||||
|
def backend_get_credentials() -> Credentials:
    """Load, refresh, or interactively obtain Google Drive credentials.

    This approach does not work for the one-box builds.

    Reads the token file at ``GOOGLE_DRIVE_TOKENS_JSON`` when it exists. If
    the stored credentials are missing or not valid, they are either
    refreshed (expired, with a refresh token) or re-created through a local
    server OAuth flow based on ``GOOGLE_DRIVE_CREDENTIAL_JSON``; the result
    is then written back to the token file.

    Returns:
        The usable credentials object.
    """
    stored = None
    if os.path.exists(GOOGLE_DRIVE_TOKENS_JSON):
        stored = Credentials.from_authorized_user_file(GOOGLE_DRIVE_TOKENS_JSON, SCOPES)

    # Already usable: nothing to refresh and nothing to persist.
    if stored and stored.valid:
        return stored

    if stored and stored.expired and stored.refresh_token:
        stored.refresh(Request())
        fresh = stored
    else:
        # No refreshable credentials available, so run the interactive flow.
        oauth_flow = InstalledAppFlow.from_client_secrets_file(
            GOOGLE_DRIVE_CREDENTIAL_JSON, SCOPES
        )
        fresh = oauth_flow.run_local_server()

    serialized = fresh.to_json()
    with open(GOOGLE_DRIVE_TOKENS_JSON, "w") as token_file:
        token_file.write(serialized)

    return fresh
|
||||||
|
|
||||||
|
|
||||||
|
def get_drive_tokens(token_path: str = GOOGLE_DRIVE_TOKENS_JSON) -> Any:
    """Return valid Google Drive credentials from ``token_path``, or ``None``.

    Never starts an interactive flow. Expired credentials that carry a
    refresh token are refreshed once; on success the refreshed tokens are
    written back to ``token_path``.

    Args:
        token_path: Location of the authorized-user token file.

    Returns:
        The credentials object when it is valid (possibly after a refresh),
        otherwise ``None`` (file missing, credentials unusable, or refresh
        failed).
    """
    if not os.path.exists(token_path):
        return None

    tokens = Credentials.from_authorized_user_file(token_path, SCOPES)
    if not tokens:
        return None
    if tokens.valid:
        return tokens

    refreshable = tokens.expired and tokens.refresh_token
    if not refreshable:
        return None

    try:
        tokens.refresh(Request())
        if not tokens.valid:
            return None
        # Persist the refreshed tokens so later calls can reuse them.
        with open(token_path, "w") as token_file:
            token_file.write(tokens.to_json())
        return tokens
    except Exception as e:
        logger.exception(f"Failed to refresh google drive access token due to: {e}")
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def verify_csrf(user_id: str, state: str) -> None:
    """Check the OAuth callback ``state`` against the value stored for the user.

    Args:
        user_id: Key under which the expected state was stored.
        state: The ``state`` value received on the callback.

    Raises:
        PermissionError: If the stored value is not a string or does not
            match ``state``.

    NOTE(review): whatever ``load`` raises for an unknown ``user_id`` is
    propagated unchanged — confirm callers expect that.
    """
    # Local import keeps this block self-contained.
    import hmac

    expected = get_dynamic_config_store().load(user_id)

    # Compare in constant time, and never let a missing/None stored state
    # compare equal to a missing/None incoming state.
    matches = (
        isinstance(expected, str)
        and isinstance(state, str)
        and hmac.compare_digest(expected.encode("utf-8"), state.encode("utf-8"))
    )
    if not matches:
        raise PermissionError(
            "State from Google Drive Connector callback does not match expected"
        )
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_url(
    user_id: str, credentials_file: str = GOOGLE_DRIVE_CREDENTIAL_JSON
) -> str:
    """Build the Google consent URL and remember its ``state`` for the user.

    Args:
        user_id: Key under which the generated state is stored, so the
            callback can later be verified against it.
        credentials_file: Path to the OAuth client secrets file.

    Returns:
        The authorization URL to send the user to.
    """
    flow = InstalledAppFlow.from_client_secrets_file(
        credentials_file,
        scopes=SCOPES,
        redirect_uri=FRONTEND_GOOGLE_DRIVE_REDIRECT,
    )
    # authorization_url returns (url, state); use the returned state directly
    # instead of re-parsing it from the URL, which could fall back to None.
    auth_url, state = flow.authorization_url(prompt="consent")

    get_dynamic_config_store().store(user_id, state)  # type: ignore
    return str(auth_url)
|
||||||
|
|
||||||
|
|
||||||
|
def save_access_tokens(
    auth_code: str,
    token_path: str = GOOGLE_DRIVE_TOKENS_JSON,
    credentials_file: str = GOOGLE_DRIVE_CREDENTIAL_JSON,
) -> Any:
    """Exchange an OAuth authorization code for tokens and persist them.

    Args:
        auth_code: Authorization code received on the OAuth callback.
        token_path: File the resulting tokens are written to.
        credentials_file: Path to the OAuth client secrets file.

    Returns:
        The credentials object produced by the flow.

    Raises:
        PermissionError: If the freshly written tokens cannot be loaded back
            as usable credentials.
    """
    flow = InstalledAppFlow.from_client_secrets_file(
        credentials_file, scopes=SCOPES, redirect_uri=FRONTEND_GOOGLE_DRIVE_REDIRECT
    )
    flow.fetch_token(code=auth_code)
    creds = flow.credentials

    # dirname is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    token_dir = os.path.dirname(token_path)
    if token_dir:
        os.makedirs(token_dir, exist_ok=True)
    with open(token_path, "w") as token_file:
        token_file.write(creds.to_json())

    # Read the tokens back through the normal path to confirm they work.
    if not get_drive_tokens(token_path):
        raise PermissionError("Not able to access Google Drive.")

    return creds
|
@@ -1,5 +1,4 @@
|
|||||||
from danswer.dynamic_configs import get_dynamic_config_store
|
from danswer.dynamic_configs import get_dynamic_config_store
|
||||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
@@ -1,7 +1,10 @@
|
|||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from danswer.auth.users import current_admin_user
|
from danswer.auth.users import current_admin_user
|
||||||
from danswer.configs.constants import DocumentSource
|
from danswer.configs.constants import DocumentSource
|
||||||
|
from danswer.configs.constants import NO_AUTH_USER
|
||||||
|
from danswer.connectors.google_drive.connector_auth import get_auth_url
|
||||||
|
from danswer.connectors.google_drive.connector_auth import get_drive_tokens
|
||||||
|
from danswer.connectors.google_drive.connector_auth import save_access_tokens
|
||||||
|
from danswer.connectors.google_drive.connector_auth import verify_csrf
|
||||||
from danswer.connectors.models import InputType
|
from danswer.connectors.models import InputType
|
||||||
from danswer.connectors.slack.config import get_slack_config
|
from danswer.connectors.slack.config import get_slack_config
|
||||||
from danswer.connectors.slack.config import SlackConfig
|
from danswer.connectors.slack.config import SlackConfig
|
||||||
@@ -12,17 +15,43 @@ from danswer.db.models import IndexAttempt
|
|||||||
from danswer.db.models import IndexingStatus
|
from danswer.db.models import IndexingStatus
|
||||||
from danswer.db.models import User
|
from danswer.db.models import User
|
||||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||||
|
from danswer.server.models import AuthStatus
|
||||||
|
from danswer.server.models import AuthUrl
|
||||||
|
from danswer.server.models import GDriveCallback
|
||||||
|
from danswer.server.models import IndexAttemptSnapshot
|
||||||
|
from danswer.server.models import ListWebsiteIndexAttemptsResponse
|
||||||
|
from danswer.server.models import WebIndexAttemptRequest
|
||||||
from danswer.utils.logging import setup_logger
|
from danswer.utils.logging import setup_logger
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/admin")
|
router = APIRouter(prefix="/admin")
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/connectors/google-drive/check-auth", response_model=AuthStatus)
def check_drive_tokens(_: User = Depends(current_admin_user)) -> AuthStatus:
    """Report whether Google Drive tokens are currently available.

    ``authenticated`` is true exactly when ``get_drive_tokens()`` returns
    something other than ``None``.
    """
    has_tokens = get_drive_tokens() is not None
    return AuthStatus(authenticated=has_tokens)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/connectors/google-drive/authorize", response_model=AuthUrl)
def google_drive_auth(user: User = Depends(current_admin_user)) -> AuthUrl:
    """Return the Google consent URL for the requesting admin.

    Falls back to ``NO_AUTH_USER`` as the state key when no user object is
    supplied by the dependency.
    """
    if user:
        user_id = str(user.id)
    else:
        user_id = NO_AUTH_USER

    consent_url = get_auth_url(user_id)
    return AuthUrl(auth_url=consent_url)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/connectors/google-drive/callback", status_code=201)
def google_drive_callback(
    callback: GDriveCallback = Depends(), user: User = Depends(current_admin_user)
) -> None:
    """Complete the Google Drive OAuth flow for the requesting admin.

    Verifies the callback ``state`` against the value stored for the user
    (``NO_AUTH_USER`` when no user object is supplied), then exchanges the
    authorization ``code`` for tokens and persists them.

    Raises:
        PermissionError: Propagated from ``verify_csrf`` on a state mismatch
            or from ``save_access_tokens`` when the tokens are unusable.
    """
    user_id = str(user.id) if user else NO_AUTH_USER
    verify_csrf(user_id, callback.state)
    # Deliberately do not return the credentials object: this endpoint is
    # declared as returning None, and the tokens must not become the
    # response body.
    save_access_tokens(callback.code)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/connectors/slack/config", response_model=SlackConfig)
|
@router.get("/connectors/slack/config", response_model=SlackConfig)
|
||||||
def fetch_slack_config(_: User = Depends(current_admin_user)) -> SlackConfig:
|
def fetch_slack_config(_: User = Depends(current_admin_user)) -> SlackConfig:
|
||||||
try:
|
try:
|
||||||
@@ -38,10 +67,6 @@ def modify_slack_config(
|
|||||||
update_slack_config(slack_config)
|
update_slack_config(slack_config)
|
||||||
|
|
||||||
|
|
||||||
class WebIndexAttemptRequest(BaseModel):
|
|
||||||
url: str
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/connectors/web/index-attempt", status_code=201)
|
@router.post("/connectors/web/index-attempt", status_code=201)
|
||||||
def index_website(
|
def index_website(
|
||||||
web_index_attempt_request: WebIndexAttemptRequest,
|
web_index_attempt_request: WebIndexAttemptRequest,
|
||||||
@@ -56,18 +81,6 @@ def index_website(
|
|||||||
insert_index_attempt(index_request)
|
insert_index_attempt(index_request)
|
||||||
|
|
||||||
|
|
||||||
class IndexAttemptSnapshot(BaseModel):
|
|
||||||
url: str
|
|
||||||
status: IndexingStatus
|
|
||||||
time_created: datetime
|
|
||||||
time_updated: datetime
|
|
||||||
docs_indexed: int
|
|
||||||
|
|
||||||
|
|
||||||
class ListWebsiteIndexAttemptsResponse(BaseModel):
|
|
||||||
index_attempts: list[IndexAttemptSnapshot]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/connectors/web/index-attempt")
|
@router.get("/connectors/web/index-attempt")
|
||||||
def list_website_index_attempts(
|
def list_website_index_attempts(
|
||||||
_: User = Depends(current_admin_user),
|
_: User = Depends(current_admin_user),
|
||||||
|
@@ -1,13 +1,29 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from danswer.datastores.interfaces import DatastoreFilter
|
from danswer.datastores.interfaces import DatastoreFilter
|
||||||
|
from danswer.db.models import IndexingStatus
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class AuthStatus(BaseModel):
    """Response payload carrying a single authentication flag."""

    # True when the connector is considered authenticated.
    authenticated: bool
|
||||||
|
|
||||||
|
|
||||||
|
class AuthUrl(BaseModel):
    """Response payload carrying an authorization URL."""

    # URL the client should be sent to in order to authorize.
    auth_url: str
|
||||||
|
|
||||||
|
|
||||||
|
class GDriveCallback(BaseModel):
    """Parameters received on the Google Drive OAuth callback."""

    # Opaque state value echoed back by the OAuth provider.
    state: str
    # Authorization code to be exchanged for tokens.
    code: str
|
||||||
|
|
||||||
|
|
||||||
class UserRoleResponse(BaseModel):
|
class UserRoleResponse(BaseModel):
|
||||||
role: str
|
role: str
|
||||||
|
|
||||||
|
|
||||||
class SearchDoc(BaseModel):
|
class SearchDoc(BaseModel):
|
||||||
semantic_name: str
|
semantic_identifier: str
|
||||||
link: str
|
link: str
|
||||||
blurb: str
|
blurb: str
|
||||||
source_type: str
|
source_type: str
|
||||||
@@ -31,3 +47,19 @@ class KeywordResponse(BaseModel):
|
|||||||
|
|
||||||
class UserByEmail(BaseModel):
|
class UserByEmail(BaseModel):
|
||||||
user_email: str
|
user_email: str
|
||||||
|
|
||||||
|
|
||||||
|
class WebIndexAttemptRequest(BaseModel):
    """Request body for starting a website index attempt."""

    # Presumably the address of the site to index — confirm against the caller.
    url: str
|
||||||
|
|
||||||
|
|
||||||
|
class IndexAttemptSnapshot(BaseModel):
    """Serializable view of a single index attempt."""

    # Target of the index attempt.
    url: str
    # Current state of the attempt.
    status: IndexingStatus
    # Creation and last-update timestamps of the attempt record.
    time_created: datetime
    time_updated: datetime
    # Count of documents indexed by this attempt.
    docs_indexed: int
|
||||||
|
|
||||||
|
|
||||||
|
class ListWebsiteIndexAttemptsResponse(BaseModel):
    """Response payload wrapping a list of index attempt snapshots."""

    # One snapshot per index attempt.
    index_attempts: list[IndexAttemptSnapshot]
|
||||||
|
@@ -77,7 +77,7 @@ def direct_qa(
|
|||||||
|
|
||||||
top_docs = [
|
top_docs = [
|
||||||
SearchDoc(
|
SearchDoc(
|
||||||
semantic_name=chunk.semantic_identifier,
|
semantic_identifier=chunk.semantic_identifier,
|
||||||
link=chunk.source_links.get("0") if chunk.source_links else None,
|
link=chunk.source_links.get("0") if chunk.source_links else None,
|
||||||
blurb=chunk.blurb,
|
blurb=chunk.blurb,
|
||||||
source_type=chunk.source_type,
|
source_type=chunk.source_type,
|
||||||
@@ -116,7 +116,7 @@ def stream_direct_qa(
|
|||||||
|
|
||||||
top_docs = [
|
top_docs = [
|
||||||
SearchDoc(
|
SearchDoc(
|
||||||
semantic_name=chunk.semantic_identifier,
|
semantic_identifier=chunk.semantic_identifier,
|
||||||
link=chunk.source_links.get("0") if chunk.source_links else None,
|
link=chunk.source_links.get("0") if chunk.source_links else None,
|
||||||
blurb=chunk.blurb,
|
blurb=chunk.blurb,
|
||||||
source_type=chunk.source_type,
|
source_type=chunk.source_type,
|
||||||
|
@@ -7,6 +7,7 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
|||||||
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
|
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
|
||||||
from danswer.connectors.github.batch import BatchGithubLoader
|
from danswer.connectors.github.batch import BatchGithubLoader
|
||||||
from danswer.connectors.google_drive.batch import BatchGoogleDriveLoader
|
from danswer.connectors.google_drive.batch import BatchGoogleDriveLoader
|
||||||
|
from danswer.connectors.google_drive.connector_auth import backend_get_credentials
|
||||||
from danswer.connectors.interfaces import PullLoader
|
from danswer.connectors.interfaces import PullLoader
|
||||||
from danswer.connectors.slack.batch import BatchSlackLoader
|
from danswer.connectors.slack.batch import BatchSlackLoader
|
||||||
from danswer.connectors.web.pull import WebLoader
|
from danswer.connectors.web.pull import WebLoader
|
||||||
@@ -71,6 +72,7 @@ def load_web_batch(url: str, qdrant_collection: str) -> None:
|
|||||||
|
|
||||||
def load_google_drive_batch(qdrant_collection: str) -> None:
|
def load_google_drive_batch(qdrant_collection: str) -> None:
|
||||||
logger.info("Loading documents from Google Drive.")
|
logger.info("Loading documents from Google Drive.")
|
||||||
|
backend_get_credentials()
|
||||||
load_batch(
|
load_batch(
|
||||||
BatchGoogleDriveLoader(batch_size=INDEX_BATCH_SIZE),
|
BatchGoogleDriveLoader(batch_size=INDEX_BATCH_SIZE),
|
||||||
DefaultChunker(),
|
DefaultChunker(),
|
||||||
|
Reference in New Issue
Block a user