Mirror of https://github.com/danswer-ai/danswer.git, synced 2025-06-29 17:20:44 +02:00
Fix a few bugs with Google Drive polling (#250)
- Adds some offset to the `start` for the Google Drive connector to give time for `modifiedTime` to propagate so we don't miss updates
- Moves fetching folders into a separate call since folder `modifiedTime` doesn't get updated when a file in the folder is updated
- Uses `connector_credential_pair.last_successful_index_time` instead of `updated_at` to determine the `start` for poll connectors
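Before the diff itself, a rough sketch of the scheduling idea the three fixes add up to. The names and buffer value here are illustrative assumptions, not the project's code: the poll window runs from the pair's last successful index time, pushed back by a propagation buffer, up to the moment the run starts.

import time

# illustrative buffer; the Drive connector below uses DRIVE_START_TIME_OFFSET = 60 * 10
PROPAGATION_BUFFER_SECS = 60 * 10

def poll_window(last_successful_index_time: float) -> tuple[float, float]:
    """Return (start, end) in epoch seconds for one poll run."""
    end = time.time()  # captured once; stored as the new last successful time on success
    start = max(last_successful_index_time - PROPAGATION_BUFFER_SECS, 0.0)
    return start, end

start, end = poll_window(last_successful_index_time=0.0)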
@@ -1,4 +1,6 @@
import time
from datetime import datetime
from datetime import timezone

from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import LoadConnector
@@ -7,6 +9,7 @@ from danswer.connectors.models import InputType
from danswer.datastores.indexing_pipeline import build_indexing_pipeline
from danswer.db.connector import disable_connector
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.credentials import backend_update_credential_json
from danswer.db.engine import get_db_current_time
@@ -14,7 +17,6 @@ from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_inprogress_index_attempts
from danswer.db.index_attempt import get_last_successful_attempt
from danswer.db.index_attempt import get_last_successful_attempt_start_time
from danswer.db.index_attempt import get_not_started_index_attempts
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_in_progress
@@ -62,6 +64,7 @@ def create_indexing_jobs(db_session: Session) -> None:
credential_id=attempt.credential_id,
attempt_status=IndexingStatus.FAILED,
net_docs=None,
run_dt=None,
db_session=db_session,
)

@@ -82,6 +85,7 @@ def create_indexing_jobs(db_session: Session) -> None:
credential_id=credential.id,
attempt_status=IndexingStatus.NOT_STARTED,
net_docs=None,
run_dt=None,
db_session=db_session,
)

@@ -122,6 +126,7 @@ def run_indexing_jobs(db_session: Session) -> None:
credential_id=db_credential.id,
attempt_status=IndexingStatus.IN_PROGRESS,
net_docs=None,
run_dt=None,
db_session=db_session,
)

@@ -143,6 +148,11 @@ def run_indexing_jobs(db_session: Session) -> None:

net_doc_change = 0
try:
# "official" timestamp for this run
# used for setting time bounds when fetching updates from apps + is
# stored in the DB as the last successful run time if this run succeeds
run_time = time.time()
run_dt = datetime.fromtimestamp(run_time, tz=timezone.utc)
if task == InputType.LOAD_STATE:
assert isinstance(runnable_connector, LoadConnector)
doc_batch_generator = runnable_connector.load_from_state()
@@ -154,14 +164,11 @@ def run_indexing_jobs(db_session: Session) -> None:
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
f"can't fetch time range."
)
last_run_time = get_last_successful_attempt_start_time(
last_run_time = get_last_successful_attempt_time(
attempt.connector_id, attempt.credential_id, db_session
)
# Covers very unlikely case that time offset check from DB having tiny variations that coincide with
# a new document being created
safe_last_run_time = max(last_run_time - 1, 0.0)
doc_batch_generator = runnable_connector.poll_source(
safe_last_run_time, time.time()
start=last_run_time, end=run_time
)

else:
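The hunks above appear to come from the background indexing job: it now captures one "official" run_time up front, uses it as the poll window's end, and takes the window's start from the connector-credential pair's last successful index time. A minimal, hypothetical harness illustrating that flow (the connector class and document shape are stand-ins, not Danswer's real types):

import time
from datetime import datetime, timezone
from typing import Iterator

Document = dict  # stand-in for the real Document model

class FakePollConnector:
    """Stand-in for a PollConnector-style source."""

    def poll_source(self, start: float, end: float) -> Iterator[list[Document]]:
        # pretend two documents changed inside the [start, end) window
        yield [{"id": "doc-1"}, {"id": "doc-2"}]

def run_poll(connector: FakePollConnector, last_successful_time: float) -> float:
    # capture the "official" timestamp once, before fetching; on success it
    # becomes the stored last_successful_index_time, i.e. the next run's start
    run_time = time.time()
    run_dt = datetime.fromtimestamp(run_time, tz=timezone.utc)
    for doc_batch in connector.poll_source(start=last_successful_time, end=run_time):
        print(f"run {run_dt.isoformat()}: indexing {len(doc_batch)} docs")
    return run_time

next_start = run_poll(FakePollConnector(), last_successful_time=0.0)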
@@ -184,6 +191,7 @@ def run_indexing_jobs(db_session: Session) -> None:
credential_id=db_credential.id,
attempt_status=IndexingStatus.SUCCESS,
net_docs=net_doc_change,
run_dt=run_dt,
db_session=db_session,
)

@@ -197,6 +205,7 @@ def run_indexing_jobs(db_session: Session) -> None:
credential_id=db_credential.id,
attempt_status=IndexingStatus.FAILED,
net_docs=net_doc_change,
run_dt=run_dt,
db_session=db_session,
)

@@ -18,15 +18,16 @@ from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.utils import batch_generator
from danswer.utils.logger import setup_logger
from google.auth.transport.requests import Request  # type: ignore
from google.oauth2.credentials import Credentials  # type: ignore
from google_auth_oauthlib.flow import InstalledAppFlow  # type: ignore
from googleapiclient import discovery  # type: ignore
from PyPDF2 import PdfReader

logger = setup_logger()

# allow 10 minutes for modifiedTime to get propogated
DRIVE_START_TIME_OFFSET = 60 * 10
SCOPES = [
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.metadata.readonly",
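The remaining hunks appear to be from the Google Drive connector module. For context, a hedged sketch of how a Drive `service` object like the one these helpers take is typically constructed; the token path is an assumption, while `Credentials.from_authorized_user_file` and `discovery.build` are standard google-auth / google-api-python-client calls rather than anything added by this commit:

from google.oauth2.credentials import Credentials
from googleapiclient import discovery

SCOPES = [
    "https://www.googleapis.com/auth/drive.readonly",
    "https://www.googleapis.com/auth/drive.metadata.readonly",
]

# "token.json" is a placeholder for wherever OAuth user credentials are stored
creds = Credentials.from_authorized_user_file("token.json", SCOPES)
service = discovery.build("drive", "v3", credentials=creds)  # a discovery.Resource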
@@ -42,8 +43,36 @@ ID_KEY = "id"
LINK_KEY = "link"
TYPE_KEY = "type"

GoogleDriveFileType = dict[str, Any]

def get_folder_id(

def _run_drive_file_query(
service: discovery.Resource,
query: str,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
batch_size: int = INDEX_BATCH_SIZE,
) -> Generator[GoogleDriveFileType, None, None]:
next_page_token = ""
while next_page_token is not None:
logger.debug(f"Running Google Drive fetch with query: {query}")
results = (
service.files()
.list(
pageSize=batch_size,
supportsAllDrives=include_shared,
fields="nextPageToken, files(mimeType, id, name, webViewLink)",
pageToken=next_page_token,
q=query,
)
.execute()
)
next_page_token = results.get("nextPageToken")
files = results["files"]
for file in files:
yield file


def _get_folder_id(
service: discovery.Resource, parent_id: str, folder_name: str
) -> str | None:
"""
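Stepping out of the diff for a moment: assuming the reconstructed `_run_drive_file_query` above is accurate, a usage sketch of paging through non-folder files modified after some epoch timestamp. The RFC 3339 formatting mirrors what `_get_files` does below, the folder MIME type literal stands in for the module's `DRIVE_FOLDER_TYPE` constant, and `service` is built as in the earlier sketch:

import datetime

def print_recent_files(service, modified_after_epoch: float) -> None:
    # Drive's query language expects RFC 3339 timestamps, hence isoformat() + "Z"
    time_start = (
        datetime.datetime.utcfromtimestamp(modified_after_epoch).isoformat() + "Z"
    )
    query = (
        "mimeType != 'application/vnd.google-apps.folder' "
        f"and modifiedTime >= '{time_start}'"
    )
    for file in _run_drive_file_query(service=service, query=query, batch_size=100):
        print(file["name"], file["webViewLink"])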
@@ -62,7 +91,59 @@ def get_folder_id(
return items[0]["id"] if items else None


def get_file_batches(
def _get_folders(
service: discovery.Resource,
folder_id: str | None = None, # if specified, only fetches files within this folder
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
batch_size: int = INDEX_BATCH_SIZE,
) -> Generator[GoogleDriveFileType, None, None]:
query = f"mimeType = '{DRIVE_FOLDER_TYPE}' "
if folder_id:
query += f"and '{folder_id}' in parents "
query = query.rstrip() # remove the trailing space(s)

yield from _run_drive_file_query(
service=service,
query=query,
include_shared=include_shared,
batch_size=batch_size,
)


def _get_files(
service: discovery.Resource,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
folder_id: str | None = None, # if specified, only fetches files within this folder
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
supported_drive_doc_types: list[str] = SUPPORTED_DRIVE_DOC_TYPES,
batch_size: int = INDEX_BATCH_SIZE,
) -> Generator[GoogleDriveFileType, None, None]:
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' "
if time_range_start is not None:
time_start = (
datetime.datetime.utcfromtimestamp(time_range_start).isoformat() + "Z"
)
query += f"and modifiedTime >= '{time_start}' "
if time_range_end is not None:
time_stop = datetime.datetime.utcfromtimestamp(time_range_end).isoformat() + "Z"
query += f"and modifiedTime <= '{time_stop}' "
if folder_id:
query += f"and '{folder_id}' in parents "
query = query.rstrip() # remove the trailing space(s)

files = _run_drive_file_query(
service=service,
query=query,
include_shared=include_shared,
batch_size=batch_size,
)
for file in files:
if file["mimeType"] in supported_drive_doc_types:
yield file


def get_all_files_batched(
service: discovery.Resource,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
batch_size: int = INDEX_BATCH_SIZE,
@@ -71,56 +152,36 @@ def get_file_batches(
folder_id: str | None = None, # if specified, only fetches files within this folder
# if True, will fetch files in sub-folders of the specified folder ID. Only applies if folder_id is specified.
traverse_subfolders: bool = True,
) -> Generator[list[dict[str, str]], None, None]:
next_page_token = ""
subfolders: list[dict[str, str]] = []
while next_page_token is not None:
query = ""
if time_range_start is not None:
time_start = (
datetime.datetime.utcfromtimestamp(time_range_start).isoformat() + "Z"
)
query += f"modifiedTime >= '{time_start}' "
if time_range_end is not None:
time_stop = (
datetime.datetime.utcfromtimestamp(time_range_end).isoformat() + "Z"
)
query += f"and modifiedTime <= '{time_stop}' "
if folder_id:
query += f"and '{folder_id}' in parents "
query = query.rstrip() # remove the trailing space(s)

logger.debug(f"Running Google Drive fetch with query: {query}")

results = (
service.files()
.list(
pageSize=batch_size,
supportsAllDrives=include_shared,
fields="nextPageToken, files(mimeType, id, name, webViewLink)",
pageToken=next_page_token,
q=query,
)
.execute()
)
next_page_token = results.get("nextPageToken")
files = results["files"]
valid_files: list[dict[str, str]] = []

for file in files:
if file["mimeType"] in SUPPORTED_DRIVE_DOC_TYPES:
valid_files.append(file)
elif file["mimeType"] == DRIVE_FOLDER_TYPE:
subfolders.append(file)
logger.info(
f"Parseable Documents in batch: {[file['name'] for file in valid_files]}"
)
yield valid_files
) -> Generator[list[GoogleDriveFileType], None, None]:
"""Gets all files matching the criteria specified by the args from Google Drive
in batches of size `batch_size`.
"""
valid_files = _get_files(
service=service,
time_range_start=time_range_start,
time_range_end=time_range_end,
folder_id=folder_id,
include_shared=include_shared,
batch_size=batch_size,
)
yield from batch_generator(
generator=valid_files,
batch_size=batch_size,
pre_batch_yield=lambda batch_files: logger.info(
f"Parseable Documents in batch: {[file['name'] for file in batch_files]}"
),
)

if traverse_subfolders:
subfolders = _get_folders(
service=service,
folder_id=folder_id,
include_shared=include_shared,
batch_size=batch_size,
)
for subfolder in subfolders:
logger.info("Fetching all files in subfolder: " + subfolder["name"])
yield from get_file_batches(
yield from get_all_files_batched(
service=service,
include_shared=include_shared,
batch_size=batch_size,
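If the reconstructed `get_all_files_batched` above is right, and assuming its elided parameters include the time-range bounds its body passes through, driving it for an incremental pass might look roughly like this (`service` as before; the window length and counting are illustrative):

import time

def count_recently_modified_docs(service, window_secs: float = 3600.0) -> int:
    end = time.time()
    start = end - window_secs
    total = 0
    # folders are walked in a separate call inside get_all_files_batched, since a
    # folder's modifiedTime does not change when a file inside it changes
    for batch in get_all_files_batched(
        service=service,
        time_range_start=start,
        time_range_end=end,
    ):
        total += len(batch)
    return total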
@@ -190,7 +251,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
folder_names = path.split("/")
parent_id = "root"
for folder_name in folder_names:
found_parent_id = get_folder_id(
found_parent_id = _get_folder_id(
service=service, parent_id=parent_id, folder_name=folder_name
)
if found_parent_id is None:
@@ -228,7 +289,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):

file_batches = chain(
*[
get_file_batches(
get_all_files_batched(
service=service,
include_shared=self.include_shared,
batch_size=self.batch_size,
@@ -264,4 +325,9 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
yield from self._fetch_docs_from_drive(start, end)
# need to subtract 10 minutes from start time to account for modifiedTime propogation
# if a document is modified, it takes some time for the API to reflect these changes
# if we do not have an offset, then we may "miss" the update when polling
yield from self._fetch_docs_from_drive(
max(start - DRIVE_START_TIME_OFFSET, 0, 0), end
)
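A small, purely illustrative simulation (made-up timestamps) of why `poll_source` now pushes `start` back by `DRIVE_START_TIME_OFFSET`: a document modified shortly before the previous poll may only become visible in the API afterwards, and without the offset the next window's `modifiedTime >= start` filter would skip it.

DRIVE_START_TIME_OFFSET = 60 * 10  # mirrors the constant added in the connector

prev_poll_end = 1_000_000.0              # end of the previous poll window (epoch secs)
doc_modified_at = prev_poll_end - 120    # doc changed 2 minutes before that poll ran...
# ...but its new modifiedTime only propagated to the API after the poll, so it was missed

next_start_naive = prev_poll_end
next_start_offset = max(prev_poll_end - DRIVE_START_TIME_OFFSET, 0)

print(next_start_naive <= doc_modified_at)   # False: the update is missed again
print(next_start_offset <= doc_modified_at)  # True: the 10-minute offset recovers it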
backend/danswer/connectors/utils.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from itertools import islice
from typing import TypeVar

T = TypeVar("T")


def batch_generator(
    generator: Iterator[T],
    batch_size: int,
    pre_batch_yield: Callable[[list[T]], None] | None = None,
) -> Generator[list[T], None, None]:
    while True:
        batch = list(islice(generator, batch_size))
        if not batch:
            return

        if pre_batch_yield:
            pre_batch_yield(batch)
        yield batch
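A quick usage sketch of `batch_generator` as defined above (values are arbitrary):

numbers = iter(range(7))

for batch in batch_generator(
    generator=numbers,
    batch_size=3,
    pre_batch_yield=lambda b: print(f"about to yield {len(b)} items"),
):
    print(batch)
# prints [0, 1, 2], then [3, 4, 5], then [6], announcing each batch size first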
@@ -1,3 +1,5 @@
from datetime import datetime

from danswer.db.connector import fetch_connector_by_id
from danswer.db.credentials import fetch_credential_by_id
from danswer.db.models import ConnectorCredentialPair
@@ -6,7 +8,6 @@ from danswer.db.models import User
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from fastapi import HTTPException
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session

@@ -35,11 +36,29 @@ def get_connector_credential_pair(
return result.scalar_one_or_none()


def get_last_successful_attempt_time(
connector_id: int,
credential_id: int,
db_session: Session,
) -> float:
connector_credential_pair = get_connector_credential_pair(
connector_id, credential_id, db_session
)
if (
connector_credential_pair is None
or connector_credential_pair.last_successful_index_time is None
):
return 0.0

return connector_credential_pair.last_successful_index_time.timestamp()


def update_connector_credential_pair(
connector_id: int,
credential_id: int,
attempt_status: IndexingStatus,
net_docs: int | None,
run_dt: datetime | None,
db_session: Session,
) -> None:
cc_pair = get_connector_credential_pair(connector_id, credential_id, db_session)
@@ -50,8 +69,10 @@ def update_connector_credential_pair(
)
return
cc_pair.last_attempt_status = attempt_status
if attempt_status == IndexingStatus.SUCCESS:
cc_pair.last_successful_index_time = func.now() # type:ignore
# simply don't update last_successful_index_time if run_dt is not specified
# at worst, this would result in re-indexing documents that were already indexed
if attempt_status == IndexingStatus.SUCCESS and run_dt is not None:
cc_pair.last_successful_index_time = run_dt
if net_docs is not None:
cc_pair.total_docs_indexed += net_docs
db_session.commit()
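To make the new bookkeeping concrete, here is a hypothetical in-memory model of what these two helpers now maintain (a sketch, not the real SQLAlchemy code): the stored last successful index time only advances when an attempt succeeds with a known run time, and it is exactly what the poller reads back as the next `start`.

from dataclasses import dataclass
from datetime import datetime, timezone

@dataclass
class FakeCCPair:
    last_successful_index_time: datetime | None = None
    total_docs_indexed: int = 0

def record_attempt(
    pair: FakeCCPair, success: bool, run_dt: datetime | None, net_docs: int | None
) -> None:
    # mirror the diff: only advance the timestamp on success with a known run_dt;
    # otherwise leave it alone, which at worst re-indexes already-indexed documents
    if success and run_dt is not None:
        pair.last_successful_index_time = run_dt
    if net_docs is not None:
        pair.total_docs_indexed += net_docs

pair = FakeCCPair()
record_attempt(pair, success=True, run_dt=datetime.now(timezone.utc), net_docs=12)
next_poll_start = (
    pair.last_successful_index_time.timestamp() if pair.last_successful_index_time else 0.0
)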
@@ -40,18 +40,6 @@ def get_db_current_time(db_session: Session) -> datetime:
return result


def translate_db_time_to_server_time(
db_time: datetime, db_session: Session
) -> datetime:
"""If a different database driver is used which does not include timezone info,
this should hit an exception rather than being wrong"""
server_now = datetime.now(timezone.utc)
db_now = get_db_current_time(db_session)
time_diff = server_now - db_now
logger.debug(f"Server time to DB time offset: {time_diff.total_seconds()} seconds")
return db_time + time_diff


def build_connection_string(
*,
db_api: str = ASYNC_DB_API,

@@ -1,4 +1,3 @@
from danswer.db.engine import translate_db_time_to_server_time
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.utils.logger import setup_logger
@@ -88,18 +87,3 @@ def get_last_successful_attempt(
stmt = stmt.order_by(desc(IndexAttempt.time_created))

return db_session.execute(stmt).scalars().first()


def get_last_successful_attempt_start_time(
connector_id: int,
credential_id: int,
db_session: Session,
) -> float:
"""Technically the start time is a bit later than creation but for intended use, it doesn't matter"""
last_indexing = get_last_successful_attempt(connector_id, credential_id, db_session)
if last_indexing is None:
return 0.0
last_index_start = translate_db_time_to_server_time(
last_indexing.time_created, db_session
)
return last_index_start.timestamp()
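For context on what was removed: the old flow derived the poll start from an index-attempt row's DB-side creation time, so it had to correct for any skew between the database clock and the application-server clock, plus a one-second `safe_last_run_time` fudge. Storing the server-side `run_dt` on the connector-credential pair makes both unnecessary. A sketch of the removed helper's assumed behavior, restated with the DB "now" passed in explicitly:

from datetime import datetime, timezone

def translate_db_time_to_server_time(db_time: datetime, db_now: datetime) -> datetime:
    # shift a DB timestamp by the measured server-vs-DB clock difference before
    # comparing it against server-side epoch times
    server_now = datetime.now(timezone.utc)
    return db_time + (server_now - db_now)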