Fix a few bugs with Google Drive polling (#250)

- Subtracts a 10-minute offset from the `start` time for the Google Drive connector so `modifiedTime` has time to propagate and we don't miss updates (see the sketch after this list)
- Moves fetching folders into a separate call since folder `modifiedTime` doesn't get updated when a file in the folder is updated
- Uses `connector_credential_pair.last_successful_index_time` instead of `updated_at` to determine the `start` for poll connectors
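
Taken together, the changes shape the poll window roughly as follows. This is a minimal sketch under assumed helper names — only `DRIVE_START_TIME_OFFSET`, the 10-minute value, and `last_successful_index_time` come from the diff below; the helper functions themselves are illustrative:

```python
import time

DRIVE_START_TIME_OFFSET = 60 * 10  # 10 minutes, matches the connector constant below


def poll_window(last_successful_index_time: float | None) -> tuple[float, float]:
    # start from the last *successful* index run (0.0 if the pair has never succeeded)
    start = last_successful_index_time or 0.0
    end = time.time()  # the run's "official" timestamp, persisted only on success
    return start, end


def drive_poll_start(start: float) -> float:
    # the Google Drive connector additionally backs `start` off by 10 minutes so that
    # slow modifiedTime propagation cannot hide documents edited shortly before the
    # previous run finished
    return max(start - DRIVE_START_TIME_OFFSET, 0)
```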
Author: Chris Weaver
Date: 2023-07-28 18:27:32 -07:00
Committed by: GitHub
Parent: 62afbcb178
Commit: 3e8f5fa47e
6 changed files with 180 additions and 90 deletions

View File

@ -1,4 +1,6 @@
import time
from datetime import datetime
from datetime import timezone
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import LoadConnector
@ -7,6 +9,7 @@ from danswer.connectors.models import InputType
from danswer.datastores.indexing_pipeline import build_indexing_pipeline
from danswer.db.connector import disable_connector
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.credentials import backend_update_credential_json
from danswer.db.engine import get_db_current_time
@ -14,7 +17,6 @@ from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_inprogress_index_attempts
from danswer.db.index_attempt import get_last_successful_attempt
from danswer.db.index_attempt import get_last_successful_attempt_start_time
from danswer.db.index_attempt import get_not_started_index_attempts
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_in_progress
@ -62,6 +64,7 @@ def create_indexing_jobs(db_session: Session) -> None:
credential_id=attempt.credential_id,
attempt_status=IndexingStatus.FAILED,
net_docs=None,
run_dt=None,
db_session=db_session,
)
@ -82,6 +85,7 @@ def create_indexing_jobs(db_session: Session) -> None:
credential_id=credential.id,
attempt_status=IndexingStatus.NOT_STARTED,
net_docs=None,
run_dt=None,
db_session=db_session,
)
@ -122,6 +126,7 @@ def run_indexing_jobs(db_session: Session) -> None:
credential_id=db_credential.id,
attempt_status=IndexingStatus.IN_PROGRESS,
net_docs=None,
run_dt=None,
db_session=db_session,
)
@ -143,6 +148,11 @@ def run_indexing_jobs(db_session: Session) -> None:
net_doc_change = 0
try:
# "official" timestamp for this run
# used for setting time bounds when fetching updates from apps + is
# stored in the DB as the last successful run time if this run succeeds
run_time = time.time()
run_dt = datetime.fromtimestamp(run_time, tz=timezone.utc)
if task == InputType.LOAD_STATE:
assert isinstance(runnable_connector, LoadConnector)
doc_batch_generator = runnable_connector.load_from_state()
@ -154,14 +164,11 @@ def run_indexing_jobs(db_session: Session) -> None:
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
f"can't fetch time range."
)
last_run_time = get_last_successful_attempt_start_time(
last_run_time = get_last_successful_attempt_time(
attempt.connector_id, attempt.credential_id, db_session
)
# Covers very unlikely case that time offset check from DB having tiny variations that coincide with
# a new document being created
safe_last_run_time = max(last_run_time - 1, 0.0)
doc_batch_generator = runnable_connector.poll_source(
safe_last_run_time, time.time()
start=last_run_time, end=run_time
)
else:
@ -184,6 +191,7 @@ def run_indexing_jobs(db_session: Session) -> None:
credential_id=db_credential.id,
attempt_status=IndexingStatus.SUCCESS,
net_docs=net_doc_change,
run_dt=run_dt,
db_session=db_session,
)
@ -197,6 +205,7 @@ def run_indexing_jobs(db_session: Session) -> None:
credential_id=db_credential.id,
attempt_status=IndexingStatus.FAILED,
net_docs=net_doc_change,
run_dt=run_dt,
db_session=db_session,
)
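
For orientation, the per-run flow after this change looks roughly like the following. This is a condensed sketch: `index_doc_batch` is a stand-in for the real indexing pipeline, error handling is trimmed, and only the imported names come from the diff above:

```python
import time
from datetime import datetime, timezone

from danswer.connectors.interfaces import PollConnector
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.models import IndexingStatus
from sqlalchemy.orm import Session


def index_doc_batch(doc_batch: list) -> None:
    ...  # placeholder for the real indexing pipeline


def run_poll_attempt(
    connector: PollConnector, connector_id: int, credential_id: int, db_session: Session
) -> None:
    # one "official" timestamp per run: it bounds the poll and, on success, becomes
    # the pair's last_successful_index_time
    run_time = time.time()
    run_dt = datetime.fromtimestamp(run_time, tz=timezone.utc)

    # start from the previous successful run (0.0 if there has never been one)
    last_run_time = get_last_successful_attempt_time(
        connector_id, credential_id, db_session
    )

    for doc_batch in connector.poll_source(start=last_run_time, end=run_time):
        index_doc_batch(doc_batch)

    # only a successful run with a run_dt advances last_successful_index_time
    update_connector_credential_pair(
        connector_id=connector_id,
        credential_id=credential_id,
        attempt_status=IndexingStatus.SUCCESS,
        net_docs=0,
        run_dt=run_dt,
        db_session=db_session,
    )
```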

View File

@ -18,15 +18,16 @@ from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.utils import batch_generator
from danswer.utils.logger import setup_logger
from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials # type: ignore
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from googleapiclient import discovery # type: ignore
from PyPDF2 import PdfReader
logger = setup_logger()
# allow 10 minutes for modifiedTime to get propagated
DRIVE_START_TIME_OFFSET = 60 * 10
SCOPES = [
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.metadata.readonly",
@ -42,8 +43,36 @@ ID_KEY = "id"
LINK_KEY = "link"
TYPE_KEY = "type"
GoogleDriveFileType = dict[str, Any]
def get_folder_id(
def _run_drive_file_query(
service: discovery.Resource,
query: str,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
batch_size: int = INDEX_BATCH_SIZE,
) -> Generator[GoogleDriveFileType, None, None]:
next_page_token = ""
while next_page_token is not None:
logger.debug(f"Running Google Drive fetch with query: {query}")
results = (
service.files()
.list(
pageSize=batch_size,
supportsAllDrives=include_shared,
fields="nextPageToken, files(mimeType, id, name, webViewLink)",
pageToken=next_page_token,
q=query,
)
.execute()
)
next_page_token = results.get("nextPageToken")
files = results["files"]
for file in files:
yield file
def _get_folder_id(
service: discovery.Resource, parent_id: str, folder_name: str
) -> str | None:
"""
@ -62,7 +91,59 @@ def get_folder_id(
return items[0]["id"] if items else None
def get_file_batches(
def _get_folders(
service: discovery.Resource,
folder_id: str | None = None, # if specified, only fetches files within this folder
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
batch_size: int = INDEX_BATCH_SIZE,
) -> Generator[GoogleDriveFileType, None, None]:
query = f"mimeType = '{DRIVE_FOLDER_TYPE}' "
if folder_id:
query += f"and '{folder_id}' in parents "
query = query.rstrip() # remove the trailing space(s)
yield from _run_drive_file_query(
service=service,
query=query,
include_shared=include_shared,
batch_size=batch_size,
)
def _get_files(
service: discovery.Resource,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
folder_id: str | None = None, # if specified, only fetches files within this folder
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
supported_drive_doc_types: list[str] = SUPPORTED_DRIVE_DOC_TYPES,
batch_size: int = INDEX_BATCH_SIZE,
) -> Generator[GoogleDriveFileType, None, None]:
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' "
if time_range_start is not None:
time_start = (
datetime.datetime.utcfromtimestamp(time_range_start).isoformat() + "Z"
)
query += f"and modifiedTime >= '{time_start}' "
if time_range_end is not None:
time_stop = datetime.datetime.utcfromtimestamp(time_range_end).isoformat() + "Z"
query += f"and modifiedTime <= '{time_stop}' "
if folder_id:
query += f"and '{folder_id}' in parents "
query = query.rstrip() # remove the trailing space(s)
files = _run_drive_file_query(
service=service,
query=query,
include_shared=include_shared,
batch_size=batch_size,
)
for file in files:
if file["mimeType"] in supported_drive_doc_types:
yield file
def get_all_files_batched(
service: discovery.Resource,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
batch_size: int = INDEX_BATCH_SIZE,
@ -71,56 +152,36 @@ def get_file_batches(
folder_id: str | None = None, # if specified, only fetches files within this folder
# if True, will fetch files in sub-folders of the specified folder ID. Only applies if folder_id is specified.
traverse_subfolders: bool = True,
) -> Generator[list[dict[str, str]], None, None]:
next_page_token = ""
subfolders: list[dict[str, str]] = []
while next_page_token is not None:
query = ""
if time_range_start is not None:
time_start = (
datetime.datetime.utcfromtimestamp(time_range_start).isoformat() + "Z"
)
query += f"modifiedTime >= '{time_start}' "
if time_range_end is not None:
time_stop = (
datetime.datetime.utcfromtimestamp(time_range_end).isoformat() + "Z"
)
query += f"and modifiedTime <= '{time_stop}' "
if folder_id:
query += f"and '{folder_id}' in parents "
query = query.rstrip() # remove the trailing space(s)
logger.debug(f"Running Google Drive fetch with query: {query}")
results = (
service.files()
.list(
pageSize=batch_size,
supportsAllDrives=include_shared,
fields="nextPageToken, files(mimeType, id, name, webViewLink)",
pageToken=next_page_token,
q=query,
)
.execute()
)
next_page_token = results.get("nextPageToken")
files = results["files"]
valid_files: list[dict[str, str]] = []
for file in files:
if file["mimeType"] in SUPPORTED_DRIVE_DOC_TYPES:
valid_files.append(file)
elif file["mimeType"] == DRIVE_FOLDER_TYPE:
subfolders.append(file)
logger.info(
f"Parseable Documents in batch: {[file['name'] for file in valid_files]}"
)
yield valid_files
) -> Generator[list[GoogleDriveFileType], None, None]:
"""Gets all files matching the criteria specified by the args from Google Drive
in batches of size `batch_size`.
"""
valid_files = _get_files(
service=service,
time_range_start=time_range_start,
time_range_end=time_range_end,
folder_id=folder_id,
include_shared=include_shared,
batch_size=batch_size,
)
yield from batch_generator(
generator=valid_files,
batch_size=batch_size,
pre_batch_yield=lambda batch_files: logger.info(
f"Parseable Documents in batch: {[file['name'] for file in batch_files]}"
),
)
if traverse_subfolders:
subfolders = _get_folders(
service=service,
folder_id=folder_id,
include_shared=include_shared,
batch_size=batch_size,
)
for subfolder in subfolders:
logger.info("Fetching all files in subfolder: " + subfolder["name"])
yield from get_file_batches(
yield from get_all_files_batched(
service=service,
include_shared=include_shared,
batch_size=batch_size,
@ -190,7 +251,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
folder_names = path.split("/")
parent_id = "root"
for folder_name in folder_names:
found_parent_id = get_folder_id(
found_parent_id = _get_folder_id(
service=service, parent_id=parent_id, folder_name=folder_name
)
if found_parent_id is None:
@ -228,7 +289,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
file_batches = chain(
*[
get_file_batches(
get_all_files_batched(
service=service,
include_shared=self.include_shared,
batch_size=self.batch_size,
@ -264,4 +325,9 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
yield from self._fetch_docs_from_drive(start, end)
# need to subtract 10 minutes from the start time to account for modifiedTime propagation
# if a document is modified, it takes some time for the API to reflect these changes
# if we do not have an offset, then we may "miss" the update when polling
yield from self._fetch_docs_from_drive(
max(start - DRIVE_START_TIME_OFFSET, 0), end
)
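
For a concrete sense of what `_get_files` sends to the Drive API, here is roughly the query it builds for a poll window — a sketch that mirrors the construction above (the folder ID and timestamps are illustrative). Folders are deliberately excluded from this time-filtered query and enumerated separately via `_get_folders`, since a folder's `modifiedTime` does not change when a file inside it does:

```python
import datetime

# standard Google Drive MIME type for folders (DRIVE_FOLDER_TYPE in the connector)
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"


def build_files_query(
    time_range_start: float | None,
    time_range_end: float | None,
    folder_id: str | None,
) -> str:
    # mirrors the query construction in _get_files above
    query = f"mimeType != '{DRIVE_FOLDER_TYPE}' "
    if time_range_start is not None:
        time_start = datetime.datetime.utcfromtimestamp(time_range_start).isoformat() + "Z"
        query += f"and modifiedTime >= '{time_start}' "
    if time_range_end is not None:
        time_stop = datetime.datetime.utcfromtimestamp(time_range_end).isoformat() + "Z"
        query += f"and modifiedTime <= '{time_stop}' "
    if folder_id:
        query += f"and '{folder_id}' in parents "
    return query.rstrip()


# e.g. a one-hour window inside an illustrative folder (output wrapped here):
print(build_files_query(1690500000, 1690503600, "folder123"))
# mimeType != 'application/vnd.google-apps.folder' and modifiedTime >= '2023-07-27T23:20:00Z'
# and modifiedTime <= '2023-07-28T00:20:00Z' and 'folder123' in parents
```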

View File

@ -0,0 +1,22 @@
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from itertools import islice
from typing import TypeVar
T = TypeVar("T")
def batch_generator(
generator: Iterator[T],
batch_size: int,
pre_batch_yield: Callable[[list[T]], None] | None = None,
) -> Generator[list[T], None, None]:
while True:
batch = list(islice(generator, batch_size))
if not batch:
return
if pre_batch_yield:
pre_batch_yield(batch)
yield batch
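
A quick usage sketch for the new helper; note that the input must already be an iterator (islice consumes from the current position), so a plain range or list needs an explicit iter() first:

```python
docs = iter(range(7))  # must be an iterator, not a re-iterable like range/list

for batch in batch_generator(
    generator=docs,
    batch_size=3,
    pre_batch_yield=lambda b: print(f"batch of {len(b)}"),
):
    print(batch)
# batch of 3
# [0, 1, 2]
# batch of 3
# [3, 4, 5]
# batch of 1
# [6]
```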

View File

@ -1,3 +1,5 @@
from datetime import datetime
from danswer.db.connector import fetch_connector_by_id
from danswer.db.credentials import fetch_credential_by_id
from danswer.db.models import ConnectorCredentialPair
@ -6,7 +8,6 @@ from danswer.db.models import User
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from fastapi import HTTPException
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -35,11 +36,29 @@ def get_connector_credential_pair(
return result.scalar_one_or_none()
def get_last_successful_attempt_time(
connector_id: int,
credential_id: int,
db_session: Session,
) -> float:
connector_credential_pair = get_connector_credential_pair(
connector_id, credential_id, db_session
)
if (
connector_credential_pair is None
or connector_credential_pair.last_successful_index_time is None
):
return 0.0
return connector_credential_pair.last_successful_index_time.timestamp()
def update_connector_credential_pair(
connector_id: int,
credential_id: int,
attempt_status: IndexingStatus,
net_docs: int | None,
run_dt: datetime | None,
db_session: Session,
) -> None:
cc_pair = get_connector_credential_pair(connector_id, credential_id, db_session)
@ -50,8 +69,10 @@ def update_connector_credential_pair(
)
return
cc_pair.last_attempt_status = attempt_status
if attempt_status == IndexingStatus.SUCCESS:
cc_pair.last_successful_index_time = func.now() # type:ignore
# simply don't update last_successful_index_time if run_dt is not specified
# at worst, this would result in re-indexing documents that were already indexed
if attempt_status == IndexingStatus.SUCCESS and run_dt is not None:
cc_pair.last_successful_index_time = run_dt
if net_docs is not None:
cc_pair.total_docs_indexed += net_docs
db_session.commit()
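
The net effect of the two functions above is a conservative contract: a failed run (or a missing `run_dt`) leaves `last_successful_index_time` untouched, so the next poll re-covers the same window and at worst re-indexes documents rather than skipping them. A small sketch of a hypothetical caller, assuming an existing `Session` and a connector/credential pair with IDs (1, 1):

```python
from datetime import datetime, timezone

from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.models import IndexingStatus
from sqlalchemy.orm import Session


def record_run(succeeded: bool, db_session: Session) -> None:
    start = get_last_successful_attempt_time(1, 1, db_session)  # 0.0 if never succeeded
    run_dt = datetime.now(tz=timezone.utc)

    # ... poll the window (start, run_dt) and index the results here ...

    update_connector_credential_pair(
        connector_id=1,
        credential_id=1,
        attempt_status=IndexingStatus.SUCCESS if succeeded else IndexingStatus.FAILED,
        net_docs=0 if succeeded else None,
        run_dt=run_dt,
        db_session=db_session,
    )
    # on FAILED (or run_dt=None) last_successful_index_time is not advanced, so the
    # next run polls from the old `start` again; worst case is re-indexing, not a gap
```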

View File

@ -40,18 +40,6 @@ def get_db_current_time(db_session: Session) -> datetime:
return result
def translate_db_time_to_server_time(
db_time: datetime, db_session: Session
) -> datetime:
"""If a different database driver is used which does not include timezone info,
this should hit an exception rather than being wrong"""
server_now = datetime.now(timezone.utc)
db_now = get_db_current_time(db_session)
time_diff = server_now - db_now
logger.debug(f"Server time to DB time offset: {time_diff.total_seconds()} seconds")
return db_time + time_diff
def build_connection_string(
*,
db_api: str = ASYNC_DB_API,

View File

@ -1,4 +1,3 @@
from danswer.db.engine import translate_db_time_to_server_time
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.utils.logger import setup_logger
@ -88,18 +87,3 @@ def get_last_successful_attempt(
stmt = stmt.order_by(desc(IndexAttempt.time_created))
return db_session.execute(stmt).scalars().first()
def get_last_successful_attempt_start_time(
connector_id: int,
credential_id: int,
db_session: Session,
) -> float:
"""Technically the start time is a bit later than creation but for intended use, it doesn't matter"""
last_indexing = get_last_successful_attempt(connector_id, credential_id, db_session)
if last_indexing is None:
return 0.0
last_index_start = translate_db_time_to_server_time(
last_indexing.time_created, db_session
)
return last_index_start.timestamp()