Bugfix/slack stop 2 (#3916)

* Pass the indexing heartbeat callback into the slim-document retrieval functions (`retrieve_all_slim_documents` and its helpers)

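* Add stop-signal and progress callback checks inside the Confluence, Gmail, Google Drive, and Slack slim-document fetch loops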

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
This commit is contained in:
rkuo-danswer 2025-02-08 15:45:41 -08:00 committed by GitHub
parent a222fae7c8
commit 4c184bb7f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 81 additions and 15 deletions

View File

@ -365,7 +365,9 @@ def confluence_doc_sync(
slim_docs = [] slim_docs = []
logger.debug("Fetching all slim documents from confluence") logger.debug("Fetching all slim documents from confluence")
for doc_batch in confluence_connector.retrieve_all_slim_documents(): for doc_batch in confluence_connector.retrieve_all_slim_documents(
callback=callback
):
logger.debug(f"Got {len(doc_batch)} slim documents from confluence") logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
if callback: if callback:
if callback.should_stop(): if callback.should_stop():

View File

@ -15,6 +15,7 @@ logger = setup_logger()
def _get_slim_doc_generator( def _get_slim_doc_generator(
cc_pair: ConnectorCredentialPair, cc_pair: ConnectorCredentialPair,
gmail_connector: GmailConnector, gmail_connector: GmailConnector,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput: ) -> GenerateSlimDocumentOutput:
current_time = datetime.now(timezone.utc) current_time = datetime.now(timezone.utc)
start_time = ( start_time = (
@ -24,7 +25,9 @@ def _get_slim_doc_generator(
) )
return gmail_connector.retrieve_all_slim_documents( return gmail_connector.retrieve_all_slim_documents(
start=start_time, end=current_time.timestamp() start=start_time,
end=current_time.timestamp(),
callback=callback,
) )
@ -40,7 +43,9 @@ def gmail_doc_sync(
gmail_connector = GmailConnector(**cc_pair.connector.connector_specific_config) gmail_connector = GmailConnector(**cc_pair.connector.connector_specific_config)
gmail_connector.load_credentials(cc_pair.credential.credential_json) gmail_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = _get_slim_doc_generator(cc_pair, gmail_connector) slim_doc_generator = _get_slim_doc_generator(
cc_pair, gmail_connector, callback=callback
)
document_external_access: list[DocExternalAccess] = [] document_external_access: list[DocExternalAccess] = []
for slim_doc_batch in slim_doc_generator: for slim_doc_batch in slim_doc_generator:

View File

@ -21,6 +21,7 @@ _PERMISSION_ID_PERMISSION_MAP: dict[str, dict[str, Any]] = {}
def _get_slim_doc_generator( def _get_slim_doc_generator(
cc_pair: ConnectorCredentialPair, cc_pair: ConnectorCredentialPair,
google_drive_connector: GoogleDriveConnector, google_drive_connector: GoogleDriveConnector,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput: ) -> GenerateSlimDocumentOutput:
current_time = datetime.now(timezone.utc) current_time = datetime.now(timezone.utc)
start_time = ( start_time = (
@ -30,7 +31,9 @@ def _get_slim_doc_generator(
) )
return google_drive_connector.retrieve_all_slim_documents( return google_drive_connector.retrieve_all_slim_documents(
start=start_time, end=current_time.timestamp() start=start_time,
end=current_time.timestamp(),
callback=callback,
) )

View File

@ -20,11 +20,18 @@ def _get_slack_document_ids_and_channels(
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config) slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json) slack_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = slack_connector.retrieve_all_slim_documents() slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
channel_doc_map: dict[str, list[str]] = {} channel_doc_map: dict[str, list[str]] = {}
for doc_metadata_batch in slim_doc_generator: for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch: for doc_metadata in doc_metadata_batch:
if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
if channel_id not in channel_doc_map:
channel_doc_map[channel_id] = []
channel_doc_map[channel_id].append(doc_metadata.id)
if callback: if callback:
if callback.should_stop(): if callback.should_stop():
raise RuntimeError( raise RuntimeError(
@ -33,13 +40,6 @@ def _get_slack_document_ids_and_channels(
callback.progress("_get_slack_document_ids_and_channels", 1) callback.progress("_get_slack_document_ids_and_channels", 1)
if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
if channel_id not in channel_doc_map:
channel_doc_map[channel_id] = []
channel_doc_map[channel_id].append(doc_metadata.id)
return channel_doc_map return channel_doc_map

View File

@ -27,6 +27,7 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document from onyx.connectors.models import Document
from onyx.connectors.models import Section from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger from onyx.utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
@ -319,6 +320,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self, self,
start: SecondsSinceUnixEpoch | None = None, start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput: ) -> GenerateSlimDocumentOutput:
doc_metadata_list: list[SlimDocument] = [] doc_metadata_list: list[SlimDocument] = []
@ -386,4 +388,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE] yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:] doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
if callback:
if callback.should_stop():
raise RuntimeError(
"retrieve_all_slim_documents: Stop signal detected"
)
callback.progress("retrieve_all_slim_documents", 1)
yield doc_metadata_list yield doc_metadata_list

View File

@ -30,6 +30,7 @@ from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document from onyx.connectors.models import Document
from onyx.connectors.models import Section from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder from onyx.utils.retry_wrapper import retry_builder
@ -321,6 +322,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
self, self,
time_range_start: SecondsSinceUnixEpoch | None = None, time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None, time_range_end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput: ) -> GenerateSlimDocumentOutput:
query = _build_time_range_query(time_range_start, time_range_end) query = _build_time_range_query(time_range_start, time_range_end)
doc_batch = [] doc_batch = []
@ -343,6 +345,15 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
if len(doc_batch) > SLIM_BATCH_SIZE: if len(doc_batch) > SLIM_BATCH_SIZE:
yield doc_batch yield doc_batch
doc_batch = [] doc_batch = []
if callback:
if callback.should_stop():
raise RuntimeError(
"retrieve_all_slim_documents: Stop signal detected"
)
callback.progress("retrieve_all_slim_documents", 1)
if doc_batch: if doc_batch:
yield doc_batch yield doc_batch
@ -368,9 +379,10 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
self, self,
start: SecondsSinceUnixEpoch | None = None, start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput: ) -> GenerateSlimDocumentOutput:
try: try:
yield from self._fetch_slim_threads(start, end) yield from self._fetch_slim_threads(start, end, callback=callback)
except Exception as e: except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e): if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e

View File

@ -42,6 +42,7 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector from onyx.connectors.interfaces import SlimConnector
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder from onyx.utils.retry_wrapper import retry_builder
@ -564,6 +565,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
self, self,
start: SecondsSinceUnixEpoch | None = None, start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput: ) -> GenerateSlimDocumentOutput:
slim_batch = [] slim_batch = []
for file in self._fetch_drive_items( for file in self._fetch_drive_items(
@ -576,15 +578,26 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
if len(slim_batch) >= SLIM_BATCH_SIZE: if len(slim_batch) >= SLIM_BATCH_SIZE:
yield slim_batch yield slim_batch
slim_batch = [] slim_batch = []
if callback:
if callback.should_stop():
raise RuntimeError(
"_extract_slim_docs_from_google_drive: Stop signal detected"
)
callback.progress("_extract_slim_docs_from_google_drive", 1)
yield slim_batch yield slim_batch
def retrieve_all_slim_documents( def retrieve_all_slim_documents(
self, self,
start: SecondsSinceUnixEpoch | None = None, start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput: ) -> GenerateSlimDocumentOutput:
try: try:
yield from self._extract_slim_docs_from_google_drive(start, end) yield from self._extract_slim_docs_from_google_drive(
start, end, callback=callback
)
except Exception as e: except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e): if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e

View File

@ -7,6 +7,7 @@ from pydantic import BaseModel
from onyx.configs.constants import DocumentSource from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
SecondsSinceUnixEpoch = float SecondsSinceUnixEpoch = float
@ -63,6 +64,7 @@ class SlimConnector(BaseConnector):
self, self,
start: SecondsSinceUnixEpoch | None = None, start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput: ) -> GenerateSlimDocumentOutput:
raise NotImplementedError raise NotImplementedError

View File

@ -29,6 +29,7 @@ from onyx.connectors.onyx_jira.utils import build_jira_url
from onyx.connectors.onyx_jira.utils import extract_jira_project from onyx.connectors.onyx_jira.utils import extract_jira_project
from onyx.connectors.onyx_jira.utils import extract_text_from_adf from onyx.connectors.onyx_jira.utils import extract_text_from_adf
from onyx.connectors.onyx_jira.utils import get_comment_strs from onyx.connectors.onyx_jira.utils import get_comment_strs
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger from onyx.utils.logger import setup_logger
@ -245,6 +246,7 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
self, self,
start: SecondsSinceUnixEpoch | None = None, start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput: ) -> GenerateSlimDocumentOutput:
jql = f"project = {self.quoted_jira_project}" jql = f"project = {self.quoted_jira_project}"

View File

@ -21,6 +21,7 @@ from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_
from onyx.connectors.salesforce.sqlite_functions import get_record from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.sqlite_functions import init_db from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger from onyx.utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
@ -176,6 +177,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
self, self,
start: SecondsSinceUnixEpoch | None = None, start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput: ) -> GenerateSlimDocumentOutput:
doc_metadata_list: list[SlimDocument] = [] doc_metadata_list: list[SlimDocument] = []
for parent_object_type in self.parent_object_list: for parent_object_type in self.parent_object_list:

View File

@ -21,6 +21,7 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document from onyx.connectors.models import Document
from onyx.connectors.models import Section from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger from onyx.utils.logger import setup_logger
@ -242,6 +243,7 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
self, self,
start: SecondsSinceUnixEpoch | None = None, start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput: ) -> GenerateSlimDocumentOutput:
slim_doc_batch: list[SlimDocument] = [] slim_doc_batch: list[SlimDocument] = []
for post_id in get_all_post_ids(self.slab_bot_token): for post_id in get_all_post_ids(self.slab_bot_token):

View File

@ -27,6 +27,7 @@ from onyx.connectors.slack.utils import get_message_link
from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.utils import make_slack_api_call_w_retries from onyx.connectors.slack.utils import make_slack_api_call_w_retries
from onyx.connectors.slack.utils import SlackTextCleaner from onyx.connectors.slack.utils import SlackTextCleaner
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger from onyx.utils.logger import setup_logger
@ -98,6 +99,7 @@ def get_channel_messages(
channel: dict[str, Any], channel: dict[str, Any],
oldest: str | None = None, oldest: str | None = None,
latest: str | None = None, latest: str | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> Generator[list[MessageType], None, None]: ) -> Generator[list[MessageType], None, None]:
"""Get all messages in a channel""" """Get all messages in a channel"""
# join so that the bot can access messages # join so that the bot can access messages
@ -115,6 +117,11 @@ def get_channel_messages(
oldest=oldest, oldest=oldest,
latest=latest, latest=latest,
): ):
if callback:
if callback.should_stop():
raise RuntimeError("get_channel_messages: Stop signal detected")
callback.progress("get_channel_messages", 0)
yield cast(list[MessageType], result["messages"]) yield cast(list[MessageType], result["messages"])
@ -325,6 +332,7 @@ def _get_all_doc_ids(
channels: list[str] | None = None, channels: list[str] | None = None,
channel_name_regex_enabled: bool = False, channel_name_regex_enabled: bool = False,
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter, msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput: ) -> GenerateSlimDocumentOutput:
""" """
Get all document ids in the workspace, channel by channel Get all document ids in the workspace, channel by channel
@ -342,6 +350,7 @@ def _get_all_doc_ids(
channel_message_batches = get_channel_messages( channel_message_batches = get_channel_messages(
client=client, client=client,
channel=channel, channel=channel,
callback=callback,
) )
message_ts_set: set[str] = set() message_ts_set: set[str] = set()
@ -390,6 +399,7 @@ class SlackPollConnector(PollConnector, SlimConnector):
self, self,
start: SecondsSinceUnixEpoch | None = None, start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput: ) -> GenerateSlimDocumentOutput:
if self.client is None: if self.client is None:
raise ConnectorMissingCredentialError("Slack") raise ConnectorMissingCredentialError("Slack")
@ -398,6 +408,7 @@ class SlackPollConnector(PollConnector, SlimConnector):
client=self.client, client=self.client,
channels=self.channels, channels=self.channels,
channel_name_regex_enabled=self.channel_regex_enabled, channel_name_regex_enabled=self.channel_regex_enabled,
callback=callback,
) )
def poll_source( def poll_source(

View File

@ -20,6 +20,7 @@ from onyx.connectors.models import Document
from onyx.connectors.models import Section from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument from onyx.connectors.models import SlimDocument
from onyx.file_processing.html_utils import parse_html_page_basic from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.retry_wrapper import retry_builder from onyx.utils.retry_wrapper import retry_builder
@ -405,6 +406,7 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
self, self,
start: SecondsSinceUnixEpoch | None = None, start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput: ) -> GenerateSlimDocumentOutput:
slim_doc_batch: list[SlimDocument] = [] slim_doc_batch: list[SlimDocument] = []
if self.content_type == "articles": if self.content_type == "articles":