Fix a few bugs with Google Drive polling (#250)

- Adds an offset to the `start` time for the Google Drive connector, giving `modifiedTime` time to propagate so we don't miss updates
- Moves folder fetching into a separate call, since a folder's `modifiedTime` is not updated when a file inside it is updated
- Uses `connector_credential_pair.last_successful_index_time` instead of `updated_at` to determine the `start` for poll connectors (see the sketch below)
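
Roughly, the new polling window works as sketched below. This is a simplified stand-in rather than the code from the diff: it folds the `last_successful_index_time` lookup and the 10-minute `DRIVE_START_TIME_OFFSET` (applied inside `poll_source` in the actual change) into one hypothetical helper, `compute_poll_window`.

```python
import time
from datetime import datetime, timezone

# Same value as DRIVE_START_TIME_OFFSET in the connector: give modifiedTime
# 10 minutes to propagate before trusting the `start` boundary.
DRIVE_START_TIME_OFFSET = 60 * 10


def compute_poll_window(
    last_successful_index_time: datetime | None,
) -> tuple[float, float]:
    """Return the (start, end) epoch-seconds window for one poll run.

    `last_successful_index_time` stands in for
    connector_credential_pair.last_successful_index_time; None means the
    pair has never indexed successfully, so we poll from the beginning.
    """
    run_time = time.time()  # "official" timestamp for this run
    start = (
        last_successful_index_time.timestamp()
        if last_successful_index_time is not None
        else 0.0
    )
    # Shift the start back so documents whose modifiedTime had not yet
    # propagated by the end of the previous run are still picked up.
    return max(start - DRIVE_START_TIME_OFFSET, 0), run_time


# e.g. a pair last indexed successfully at 18:00 UTC is polled from 17:50 UTC onward
start, end = compute_poll_window(datetime(2023, 7, 28, 18, 0, tzinfo=timezone.utc))
```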
Chris Weaver committed 2023-07-28 18:27:32 -07:00 (via GitHub)
parent 62afbcb178
commit 3e8f5fa47e
6 changed files with 180 additions and 90 deletions

View File

@@ -1,4 +1,6 @@
 import time
+from datetime import datetime
+from datetime import timezone
 
 from danswer.connectors.factory import instantiate_connector
 from danswer.connectors.interfaces import LoadConnector
@@ -7,6 +9,7 @@ from danswer.connectors.models import InputType
 from danswer.datastores.indexing_pipeline import build_indexing_pipeline
 from danswer.db.connector import disable_connector
 from danswer.db.connector import fetch_connectors
+from danswer.db.connector_credential_pair import get_last_successful_attempt_time
 from danswer.db.connector_credential_pair import update_connector_credential_pair
 from danswer.db.credentials import backend_update_credential_json
 from danswer.db.engine import get_db_current_time
@@ -14,7 +17,6 @@ from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.index_attempt import create_index_attempt
 from danswer.db.index_attempt import get_inprogress_index_attempts
 from danswer.db.index_attempt import get_last_successful_attempt
-from danswer.db.index_attempt import get_last_successful_attempt_start_time
 from danswer.db.index_attempt import get_not_started_index_attempts
 from danswer.db.index_attempt import mark_attempt_failed
 from danswer.db.index_attempt import mark_attempt_in_progress
@@ -62,6 +64,7 @@ def create_indexing_jobs(db_session: Session) -> None:
             credential_id=attempt.credential_id,
             attempt_status=IndexingStatus.FAILED,
             net_docs=None,
+            run_dt=None,
             db_session=db_session,
         )
@@ -82,6 +85,7 @@ def create_indexing_jobs(db_session: Session) -> None:
             credential_id=credential.id,
             attempt_status=IndexingStatus.NOT_STARTED,
             net_docs=None,
+            run_dt=None,
             db_session=db_session,
         )
@@ -122,6 +126,7 @@ def run_indexing_jobs(db_session: Session) -> None:
             credential_id=db_credential.id,
             attempt_status=IndexingStatus.IN_PROGRESS,
             net_docs=None,
+            run_dt=None,
             db_session=db_session,
         )
@@ -143,6 +148,11 @@ def run_indexing_jobs(db_session: Session) -> None:
         net_doc_change = 0
         try:
+            # "official" timestamp for this run
+            # used for setting time bounds when fetching updates from apps + is
+            # stored in the DB as the last successful run time if this run succeeds
+            run_time = time.time()
+            run_dt = datetime.fromtimestamp(run_time, tz=timezone.utc)
             if task == InputType.LOAD_STATE:
                 assert isinstance(runnable_connector, LoadConnector)
                 doc_batch_generator = runnable_connector.load_from_state()
@@ -154,14 +164,11 @@ def run_indexing_jobs(db_session: Session) -> None:
                     f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
                     f"can't fetch time range."
                 )
-                last_run_time = get_last_successful_attempt_start_time(
+                last_run_time = get_last_successful_attempt_time(
                     attempt.connector_id, attempt.credential_id, db_session
                 )
-                # Covers very unlikely case that time offset check from DB having tiny variations that coincide with
-                # a new document being created
-                safe_last_run_time = max(last_run_time - 1, 0.0)
                 doc_batch_generator = runnable_connector.poll_source(
-                    safe_last_run_time, time.time()
+                    start=last_run_time, end=run_time
                 )
             else:
@@ -184,6 +191,7 @@ def run_indexing_jobs(db_session: Session) -> None:
             credential_id=db_credential.id,
             attempt_status=IndexingStatus.SUCCESS,
             net_docs=net_doc_change,
+            run_dt=run_dt,
             db_session=db_session,
         )
@@ -197,6 +205,7 @@ def run_indexing_jobs(db_session: Session) -> None:
             credential_id=db_credential.id,
            attempt_status=IndexingStatus.FAILED,
             net_docs=net_doc_change,
+            run_dt=run_dt,
             db_session=db_session,
         )

View File

@@ -18,15 +18,16 @@ from danswer.connectors.interfaces import PollConnector
 from danswer.connectors.interfaces import SecondsSinceUnixEpoch
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
+from danswer.connectors.utils import batch_generator
 from danswer.utils.logger import setup_logger
-from google.auth.transport.requests import Request  # type: ignore
 from google.oauth2.credentials import Credentials  # type: ignore
-from google_auth_oauthlib.flow import InstalledAppFlow  # type: ignore
 from googleapiclient import discovery  # type: ignore
 from PyPDF2 import PdfReader
 
 logger = setup_logger()
 
+# allow 10 minutes for modifiedTime to get propagated
+DRIVE_START_TIME_OFFSET = 60 * 10
 SCOPES = [
     "https://www.googleapis.com/auth/drive.readonly",
     "https://www.googleapis.com/auth/drive.metadata.readonly",
@@ -42,8 +43,36 @@ ID_KEY = "id"
 LINK_KEY = "link"
 TYPE_KEY = "type"
 
+GoogleDriveFileType = dict[str, Any]
+
 
-def get_folder_id(
+def _run_drive_file_query(
+    service: discovery.Resource,
+    query: str,
+    include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
+    batch_size: int = INDEX_BATCH_SIZE,
+) -> Generator[GoogleDriveFileType, None, None]:
+    next_page_token = ""
+    while next_page_token is not None:
+        logger.debug(f"Running Google Drive fetch with query: {query}")
+        results = (
+            service.files()
+            .list(
+                pageSize=batch_size,
+                supportsAllDrives=include_shared,
+                fields="nextPageToken, files(mimeType, id, name, webViewLink)",
+                pageToken=next_page_token,
+                q=query,
+            )
+            .execute()
+        )
+        next_page_token = results.get("nextPageToken")
+        files = results["files"]
+        for file in files:
+            yield file
+
+
+def _get_folder_id(
     service: discovery.Resource, parent_id: str, folder_name: str
 ) -> str | None:
     """
@@ -62,7 +91,59 @@ def get_folder_id(
     return items[0]["id"] if items else None
 
 
-def get_file_batches(
+def _get_folders(
+    service: discovery.Resource,
+    folder_id: str | None = None,  # if specified, only fetches files within this folder
+    include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
+    batch_size: int = INDEX_BATCH_SIZE,
+) -> Generator[GoogleDriveFileType, None, None]:
+    query = f"mimeType = '{DRIVE_FOLDER_TYPE}' "
+    if folder_id:
+        query += f"and '{folder_id}' in parents "
+    query = query.rstrip()  # remove the trailing space(s)
+
+    yield from _run_drive_file_query(
+        service=service,
+        query=query,
+        include_shared=include_shared,
+        batch_size=batch_size,
+    )
+
+
+def _get_files(
+    service: discovery.Resource,
+    time_range_start: SecondsSinceUnixEpoch | None = None,
+    time_range_end: SecondsSinceUnixEpoch | None = None,
+    folder_id: str | None = None,  # if specified, only fetches files within this folder
+    include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
+    supported_drive_doc_types: list[str] = SUPPORTED_DRIVE_DOC_TYPES,
+    batch_size: int = INDEX_BATCH_SIZE,
+) -> Generator[GoogleDriveFileType, None, None]:
+    query = f"mimeType != '{DRIVE_FOLDER_TYPE}' "
+    if time_range_start is not None:
+        time_start = (
+            datetime.datetime.utcfromtimestamp(time_range_start).isoformat() + "Z"
+        )
+        query += f"and modifiedTime >= '{time_start}' "
+    if time_range_end is not None:
+        time_stop = datetime.datetime.utcfromtimestamp(time_range_end).isoformat() + "Z"
+        query += f"and modifiedTime <= '{time_stop}' "
+    if folder_id:
+        query += f"and '{folder_id}' in parents "
+    query = query.rstrip()  # remove the trailing space(s)
+
+    files = _run_drive_file_query(
+        service=service,
+        query=query,
+        include_shared=include_shared,
+        batch_size=batch_size,
+    )
+    for file in files:
+        if file["mimeType"] in supported_drive_doc_types:
+            yield file
+
+
+def get_all_files_batched(
     service: discovery.Resource,
     include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
     batch_size: int = INDEX_BATCH_SIZE,
@@ -71,56 +152,36 @@ def get_file_batches(
     folder_id: str | None = None,  # if specified, only fetches files within this folder
     # if True, will fetch files in sub-folders of the specified folder ID. Only applies if folder_id is specified.
     traverse_subfolders: bool = True,
-) -> Generator[list[dict[str, str]], None, None]:
-    next_page_token = ""
-    subfolders: list[dict[str, str]] = []
-    while next_page_token is not None:
-        query = ""
-        if time_range_start is not None:
-            time_start = (
-                datetime.datetime.utcfromtimestamp(time_range_start).isoformat() + "Z"
-            )
-            query += f"modifiedTime >= '{time_start}' "
-        if time_range_end is not None:
-            time_stop = (
-                datetime.datetime.utcfromtimestamp(time_range_end).isoformat() + "Z"
-            )
-            query += f"and modifiedTime <= '{time_stop}' "
-        if folder_id:
-            query += f"and '{folder_id}' in parents "
-        query = query.rstrip()  # remove the trailing space(s)
-        logger.debug(f"Running Google Drive fetch with query: {query}")
-        results = (
-            service.files()
-            .list(
-                pageSize=batch_size,
-                supportsAllDrives=include_shared,
-                fields="nextPageToken, files(mimeType, id, name, webViewLink)",
-                pageToken=next_page_token,
-                q=query,
-            )
-            .execute()
-        )
-        next_page_token = results.get("nextPageToken")
-        files = results["files"]
-        valid_files: list[dict[str, str]] = []
-        for file in files:
-            if file["mimeType"] in SUPPORTED_DRIVE_DOC_TYPES:
-                valid_files.append(file)
-            elif file["mimeType"] == DRIVE_FOLDER_TYPE:
-                subfolders.append(file)
-        logger.info(
-            f"Parseable Documents in batch: {[file['name'] for file in valid_files]}"
-        )
-        yield valid_files
+) -> Generator[list[GoogleDriveFileType], None, None]:
+    """Gets all files matching the criteria specified by the args from Google Drive
+    in batches of size `batch_size`.
+    """
+    valid_files = _get_files(
+        service=service,
+        time_range_start=time_range_start,
+        time_range_end=time_range_end,
+        folder_id=folder_id,
+        include_shared=include_shared,
+        batch_size=batch_size,
+    )
+    yield from batch_generator(
+        generator=valid_files,
+        batch_size=batch_size,
+        pre_batch_yield=lambda batch_files: logger.info(
+            f"Parseable Documents in batch: {[file['name'] for file in batch_files]}"
+        ),
+    )
 
     if traverse_subfolders:
+        subfolders = _get_folders(
+            service=service,
+            folder_id=folder_id,
+            include_shared=include_shared,
+            batch_size=batch_size,
+        )
         for subfolder in subfolders:
             logger.info("Fetching all files in subfolder: " + subfolder["name"])
-            yield from get_file_batches(
+            yield from get_all_files_batched(
                 service=service,
                 include_shared=include_shared,
                 batch_size=batch_size,
@@ -190,7 +251,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
         folder_names = path.split("/")
         parent_id = "root"
         for folder_name in folder_names:
-            found_parent_id = get_folder_id(
+            found_parent_id = _get_folder_id(
                 service=service, parent_id=parent_id, folder_name=folder_name
             )
             if found_parent_id is None:
@@ -228,7 +289,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
         file_batches = chain(
             *[
-                get_file_batches(
+                get_all_files_batched(
                     service=service,
                     include_shared=self.include_shared,
                     batch_size=self.batch_size,
@@ -264,4 +325,9 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
     def poll_source(
         self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
     ) -> GenerateDocumentsOutput:
-        yield from self._fetch_docs_from_drive(start, end)
+        # need to subtract 10 minutes from start time to account for modifiedTime propagation
+        # if a document is modified, it takes some time for the API to reflect these changes
+        # if we do not have an offset, then we may "miss" the update when polling
+        yield from self._fetch_docs_from_drive(
+            max(start - DRIVE_START_TIME_OFFSET, 0), end
+        )

View File

@@ -0,0 +1,22 @@
+from collections.abc import Callable
+from collections.abc import Generator
+from collections.abc import Iterator
+from itertools import islice
+from typing import TypeVar
+
+T = TypeVar("T")
+
+
+def batch_generator(
+    generator: Iterator[T],
+    batch_size: int,
+    pre_batch_yield: Callable[[list[T]], None] | None = None,
+) -> Generator[list[T], None, None]:
+    while True:
+        batch = list(islice(generator, batch_size))
+        if not batch:
+            return
+
+        if pre_batch_yield:
+            pre_batch_yield(batch)
+        yield batch
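
For reference, a small usage sketch of the new `batch_generator` helper, assuming the module is importable at the path the Google Drive connector uses above; the expected output is shown in the comments.

```python
from danswer.connectors.utils import batch_generator

# Batch a stream of 7 ints into groups of 3, logging before each yield.
for batch in batch_generator(
    generator=iter(range(7)),
    batch_size=3,
    pre_batch_yield=lambda b: print(f"yielding batch of {len(b)}"),
):
    print(batch)
# yielding batch of 3
# [0, 1, 2]
# yielding batch of 3
# [3, 4, 5]
# yielding batch of 1
# [6]
```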

View File

@@ -1,3 +1,5 @@
+from datetime import datetime
+
 from danswer.db.connector import fetch_connector_by_id
 from danswer.db.credentials import fetch_credential_by_id
 from danswer.db.models import ConnectorCredentialPair
@@ -6,7 +8,6 @@ from danswer.db.models import User
 from danswer.server.models import StatusResponse
 from danswer.utils.logger import setup_logger
 from fastapi import HTTPException
-from sqlalchemy import func
 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -35,11 +36,29 @@ def get_connector_credential_pair(
     return result.scalar_one_or_none()
 
 
+def get_last_successful_attempt_time(
+    connector_id: int,
+    credential_id: int,
+    db_session: Session,
+) -> float:
+    connector_credential_pair = get_connector_credential_pair(
+        connector_id, credential_id, db_session
+    )
+    if (
+        connector_credential_pair is None
+        or connector_credential_pair.last_successful_index_time is None
+    ):
+        return 0.0
+
+    return connector_credential_pair.last_successful_index_time.timestamp()
+
+
 def update_connector_credential_pair(
     connector_id: int,
     credential_id: int,
     attempt_status: IndexingStatus,
     net_docs: int | None,
+    run_dt: datetime | None,
     db_session: Session,
 ) -> None:
     cc_pair = get_connector_credential_pair(connector_id, credential_id, db_session)
@@ -50,8 +69,10 @@ def update_connector_credential_pair(
         )
         return
     cc_pair.last_attempt_status = attempt_status
-    if attempt_status == IndexingStatus.SUCCESS:
-        cc_pair.last_successful_index_time = func.now()  # type:ignore
+    # simply don't update last_successful_index_time if run_dt is not specified
+    # at worst, this would result in re-indexing documents that were already indexed
+    if attempt_status == IndexingStatus.SUCCESS and run_dt is not None:
+        cc_pair.last_successful_index_time = run_dt
     if net_docs is not None:
         cc_pair.total_docs_indexed += net_docs
     db_session.commit()

View File

@@ -40,18 +40,6 @@ def get_db_current_time(db_session: Session) -> datetime:
     return result
 
 
-def translate_db_time_to_server_time(
-    db_time: datetime, db_session: Session
-) -> datetime:
-    """If a different database driver is used which does not include timezone info,
-    this should hit an exception rather than being wrong"""
-    server_now = datetime.now(timezone.utc)
-    db_now = get_db_current_time(db_session)
-    time_diff = server_now - db_now
-    logger.debug(f"Server time to DB time offset: {time_diff.total_seconds()} seconds")
-    return db_time + time_diff
-
-
 def build_connection_string(
     *,
     db_api: str = ASYNC_DB_API,

View File

@@ -1,4 +1,3 @@
-from danswer.db.engine import translate_db_time_to_server_time
 from danswer.db.models import IndexAttempt
 from danswer.db.models import IndexingStatus
 from danswer.utils.logger import setup_logger
@@ -88,18 +87,3 @@ def get_last_successful_attempt(
     stmt = stmt.order_by(desc(IndexAttempt.time_created))
 
     return db_session.execute(stmt).scalars().first()
-
-
-def get_last_successful_attempt_start_time(
-    connector_id: int,
-    credential_id: int,
-    db_session: Session,
-) -> float:
-    """Technically the start time is a bit later than creation but for intended use, it doesn't matter"""
-    last_indexing = get_last_successful_attempt(connector_id, credential_id, db_session)
-    if last_indexing is None:
-        return 0.0
-
-    last_index_start = translate_db_time_to_server_time(
-        last_indexing.time_created, db_session
-    )
-    return last_index_start.timestamp()