Connector checkpointing (#3876)

* wip checkpointing/continue on failure

more stuff for checkpointing

Basic implementation

FE stuff

More checkpointing/failure handling

rebase

rebase

initial scaffolding for IT

IT to test checkpointing

Cleanup

cleanup

Fix it

Rebase

Add todo

Fix actions IT

Test more

Pagination + fixes + cleanup

Fix IT networking

fix it

* rebase

* Address misc comments

* Address comments

* Remove unused router

* rebase

* Fix mypy

* Fixes

* fix it

* Fix tests

* Add drop index

* Add retries

* reset lock timeout

* Try hard drop of schema

* Add timeout/retries to downgrade

* rebase

* test

* test

* test

* Close all connections

* test closing idle only

* Fix it

* fix

* try using null pool

* Test

* fix

* rebase

* log

* Fix

* apply null pool

* Fix other test

* Fix quality checks

* Test not using the fixture

* Fix ordering

* fix test

* Change pooling behavior
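
Summary of the change (for readers of the squashed history above): the Slack
connector moves from PollConnector.poll_source to the new CheckpointConnector
interface. load_from_checkpoint yields Documents and ConnectorFailures for one
slice of work, then returns an updated ConnectorCheckpoint so an indexing
attempt that dies partway through can resume instead of starting over. Below is
a minimal sketch of a driver loop, assuming a connector with credentials loaded
and a start/end window as in the __main__ example in the diff; the real
indexing runner differs in detail:

    from onyx.connectors.models import ConnectorCheckpoint

    checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
    while True:
        gen = connector.load_from_checkpoint(start, end, checkpoint)
        try:
            while True:
                doc_or_failure = next(gen)
                # index the Document / record the ConnectorFailure here
        except StopIteration as e:
            # the generator's return value is the next checkpoint
            checkpoint = e.value
        if not checkpoint.has_more:
            break
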
Chris Weaver
2025-02-15 18:34:39 -08:00
committed by GitHub
parent bc087fc20e
commit f1fc8ac19b
68 changed files with 3333 additions and 1102 deletions


@@ -1,10 +1,16 @@
import contextvars
import copy
import re
from collections.abc import Callable
from collections.abc import Generator
from concurrent.futures import as_completed
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import TypedDict
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
@@ -12,14 +18,18 @@ from slack_sdk.errors import SlackApiError
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.slack.utils import expert_info_from_slack_id
@@ -33,6 +43,8 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_SLACK_LIMIT = 900
ChannelType = dict[str, Any]
MessageType = dict[str, Any]
@@ -40,6 +52,13 @@ MessageType = dict[str, Any]
ThreadType = list[MessageType]
class SlackCheckpointContent(TypedDict):
channel_ids: list[str]
channel_completion_map: dict[str, str]
current_channel: ChannelType | None
seen_thread_ts: list[str]
def _collect_paginated_channels(
client: WebClient,
exclude_archived: bool,
@@ -140,6 +159,10 @@ def get_latest_message_time(thread: ThreadType) -> datetime:
return datetime.fromtimestamp(max_ts, tz=timezone.utc)
def _build_doc_id(channel_id: str, thread_ts: str) -> str:
return f"{channel_id}__{thread_ts}"
def thread_to_doc(
channel: ChannelType,
thread: ThreadType,
@@ -182,7 +205,7 @@ def thread_to_doc(
)
return Document(
id=f"{channel_id}__{thread[0]['ts']}",
id=_build_doc_id(channel_id=channel_id, thread_ts=thread[0]["ts"]),
sections=[
Section(
link=get_message_link(event=m, client=client, channel_id=channel_id),
@@ -267,64 +290,97 @@ def filter_channels(
]
def _get_all_docs(
def _get_channel_by_id(client: WebClient, channel_id: str) -> ChannelType:
"""Get a channel by its ID.
Args:
client: The Slack WebClient instance
channel_id: The ID of the channel to fetch
Returns:
The channel information
Raises:
SlackApiError: If the channel cannot be fetched
"""
response = make_slack_api_call_w_retries(
client.conversations_info,
channel=channel_id,
)
return cast(ChannelType, response["channel"])
def _get_messages(
channel: ChannelType,
client: WebClient,
channels: list[str] | None = None,
channel_name_regex_enabled: bool = False,
oldest: str | None = None,
latest: str | None = None,
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> Generator[Document, None, None]:
"""Get all documents in the workspace, channel by channel"""
slack_cleaner = SlackTextCleaner(client=client)
) -> tuple[list[MessageType], bool]:
"""Slack goes from newest to oldest."""
# Cache to prevent refetching via API since users
user_cache: dict[str, BasicExpertInfo | None] = {}
# have to be in the channel in order to read messages
if not channel["is_member"]:
make_slack_api_call_w_retries(
client.conversations_join,
channel=channel["id"],
is_private=channel["is_private"],
)
logger.info(f"Successfully joined '{channel['name']}'")
all_channels = get_channels(client)
filtered_channels = filter_channels(
all_channels, channels, channel_name_regex_enabled
response = make_slack_api_call_w_retries(
client.conversations_history,
channel=channel["id"],
oldest=oldest,
latest=latest,
limit=_SLACK_LIMIT,
)
response.validate()
for channel in filtered_channels:
channel_docs = 0
channel_message_batches = get_channel_messages(
client=client, channel=channel, oldest=oldest, latest=latest
messages = cast(list[MessageType], response.get("messages", []))
cursor = cast(dict[str, Any], response.get("response_metadata", {})).get(
"next_cursor", ""
)
has_more = bool(cursor)
return messages, has_more
def _message_to_doc(
message: MessageType,
client: WebClient,
channel: ChannelType,
slack_cleaner: SlackTextCleaner,
user_cache: dict[str, BasicExpertInfo | None],
seen_thread_ts: set[str],
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> Document | None:
filtered_thread: ThreadType | None = None
thread_ts = message.get("thread_ts")
if thread_ts:
# skip threads we've already seen, since we've already processed all
# messages in that thread
if thread_ts in seen_thread_ts:
return None
thread = get_thread(
client=client, channel_id=channel["id"], thread_id=thread_ts
)
filtered_thread = [
message for message in thread if not msg_filter_func(message)
]
elif not msg_filter_func(message):
filtered_thread = [message]
if filtered_thread:
return thread_to_doc(
channel=channel,
thread=filtered_thread,
slack_cleaner=slack_cleaner,
client=client,
user_cache=user_cache,
)
seen_thread_ts: set[str] = set()
for message_batch in channel_message_batches:
for message in message_batch:
filtered_thread: ThreadType | None = None
thread_ts = message.get("thread_ts")
if thread_ts:
# skip threads we've already seen, since we've already processed all
# messages in that thread
if thread_ts in seen_thread_ts:
continue
seen_thread_ts.add(thread_ts)
thread = get_thread(
client=client, channel_id=channel["id"], thread_id=thread_ts
)
filtered_thread = [
message for message in thread if not msg_filter_func(message)
]
elif not msg_filter_func(message):
filtered_thread = [message]
if filtered_thread:
channel_docs += 1
yield thread_to_doc(
channel=channel,
thread=filtered_thread,
slack_cleaner=slack_cleaner,
client=client,
user_cache=user_cache,
)
logger.info(
f"Pulled {channel_docs} documents from slack channel {channel['name']}"
)
return None
def _get_all_doc_ids(
@@ -368,7 +424,7 @@ def _get_all_doc_ids(
for message_ts in message_ts_set:
channel_metadata_list.append(
SlimDocument(
id=f"{channel_id}__{message_ts}",
id=_build_doc_id(channel_id=channel_id, thread_ts=message_ts),
perm_sync_data={"channel_id": channel_id},
)
)
@@ -376,7 +432,51 @@ def _get_all_doc_ids(
yield channel_metadata_list
class SlackPollConnector(PollConnector, SlimConnector):
def _process_message(
message: MessageType,
client: WebClient,
channel: ChannelType,
slack_cleaner: SlackTextCleaner,
user_cache: dict[str, BasicExpertInfo | None],
seen_thread_ts: set[str],
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> tuple[Document | None, str | None, ConnectorFailure | None]:
thread_ts = message.get("thread_ts")
try:
# causes random failures for testing checkpointing / continue on failure
# import random
# if random.random() > 0.95:
# raise RuntimeError("Random failure :P")
doc = _message_to_doc(
message=message,
client=client,
channel=channel,
slack_cleaner=slack_cleaner,
user_cache=user_cache,
seen_thread_ts=seen_thread_ts,
msg_filter_func=msg_filter_func,
)
return (doc, thread_ts, None)
except Exception as e:
logger.exception(f"Error processing message {message['ts']}")
return (
None,
thread_ts,
ConnectorFailure(
failed_document=DocumentFailure(
document_id=_build_doc_id(
channel_id=channel["id"], thread_ts=(thread_ts or message["ts"])
),
document_link=get_message_link(message, client, channel["id"]),
),
failure_message=str(e),
exception=e,
),
)
class SlackConnector(SlimConnector, CheckpointConnector):
def __init__(
self,
channels: list[str] | None = None,
@@ -390,9 +490,14 @@ class SlackPollConnector(PollConnector, SlimConnector):
self.batch_size = batch_size
self.client: WebClient | None = None
# just used for efficiency
self.text_cleaner: SlackTextCleaner | None = None
self.user_cache: dict[str, BasicExpertInfo | None] = {}
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
bot_token = credentials["slack_bot_token"]
self.client = WebClient(token=bot_token)
self.text_cleaner = SlackTextCleaner(client=self.client)
return None
def retrieve_all_slim_documents(
@@ -411,30 +516,155 @@ class SlackPollConnector(PollConnector, SlimConnector):
callback=callback,
)
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.client is None:
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
"""Rough outline:
Step 1: Get all channels, yield back Checkpoint.
Step 2: Loop through each channel. For each channel:
Step 2.1: Get messages within the time range.
Step 2.2: Process messages in parallel, yield back docs.
Step 2.3: Update checkpoint with new_latest, seen_thread_ts, and current_channel.
Slack returns messages from newest to oldest, so we need to keep track of
the latest message we've seen in each channel.
Step 2.4: If there are no more messages in the channel, switch the current
channel to the next channel.
"""
if self.client is None or self.text_cleaner is None:
raise ConnectorMissingCredentialError("Slack")
documents: list[Document] = []
for document in _get_all_docs(
client=self.client,
channels=self.channels,
channel_name_regex_enabled=self.channel_regex_enabled,
# NOTE: need to impute to `None` instead of using 0.0, since Slack will
# throw an error if we use 0.0 on an account without infinite data
# retention
oldest=str(start) if start else None,
latest=str(end),
):
documents.append(document)
if len(documents) >= self.batch_size:
yield documents
documents = []
checkpoint_content = cast(
SlackCheckpointContent,
(
copy.deepcopy(checkpoint.checkpoint_content)
or {
"channel_ids": None,
"channel_completion_map": {},
"current_channel": None,
"seen_thread_ts": [],
}
),
)
if documents:
yield documents
# if this is the very first time we've called this, need to
# get all relevant channels and save them into the checkpoint
if checkpoint_content["channel_ids"] is None:
raw_channels = get_channels(self.client)
filtered_channels = filter_channels(
raw_channels, self.channels, self.channel_regex_enabled
)
if len(filtered_channels) == 0:
return checkpoint
checkpoint_content["channel_ids"] = [c["id"] for c in filtered_channels]
checkpoint_content["current_channel"] = filtered_channels[0]
checkpoint = ConnectorCheckpoint(
checkpoint_content=checkpoint_content, # type: ignore
has_more=True,
)
return checkpoint
final_channel_ids = checkpoint_content["channel_ids"]
channel = checkpoint_content["current_channel"]
if channel is None:
raise ValueError("current_channel key not found in checkpoint")
channel_id = channel["id"]
if channel_id not in final_channel_ids:
raise ValueError(f"Channel {channel_id} not found in checkpoint")
oldest = str(start) if start else None
latest = checkpoint_content["channel_completion_map"].get(channel_id, str(end))
seen_thread_ts = set(checkpoint_content["seen_thread_ts"])
try:
logger.debug(
f"Getting messages for channel {channel} within range {oldest} - {latest}"
)
message_batch, has_more_in_channel = _get_messages(
channel, self.client, oldest, latest
)
new_latest = message_batch[-1]["ts"] if message_batch else latest
# Process messages in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=8) as executor:
futures: list[Future] = []
for message in message_batch:
# Capture the current context so that the thread gets the current tenant ID
current_context = contextvars.copy_context()
futures.append(
executor.submit(
current_context.run,
_process_message,
message=message,
client=self.client,
channel=channel,
slack_cleaner=self.text_cleaner,
user_cache=self.user_cache,
seen_thread_ts=seen_thread_ts,
)
)
for future in as_completed(futures):
doc, thread_ts, failure = future.result()
if doc:
# seen_thread_ts is read by the worker threads, but it is only
# written here in the single consumer thread, so there are no
# concurrent writes. At worst a thread gets processed twice and is
# deduped later on.
if thread_ts not in seen_thread_ts:
yield doc
if thread_ts:
seen_thread_ts.add(thread_ts)
elif failure:
yield failure
checkpoint_content["seen_thread_ts"] = list(seen_thread_ts)
checkpoint_content["channel_completion_map"][channel["id"]] = new_latest
if has_more_in_channel:
checkpoint_content["current_channel"] = channel
else:
new_channel_id = next(
(
channel_id
for channel_id in final_channel_ids
if channel_id
not in checkpoint_content["channel_completion_map"]
),
None,
)
if new_channel_id:
new_channel = _get_channel_by_id(self.client, new_channel_id)
checkpoint_content["current_channel"] = new_channel
else:
checkpoint_content["current_channel"] = None
checkpoint = ConnectorCheckpoint(
checkpoint_content=checkpoint_content, # type: ignore
has_more=checkpoint_content["current_channel"] is not None,
)
return checkpoint
except Exception as e:
logger.exception(f"Error processing channel {channel['name']}")
yield ConnectorFailure(
failed_entity=EntityFailure(
entity_id=channel["id"],
missed_time_range=(
datetime.fromtimestamp(start, tz=timezone.utc),
datetime.fromtimestamp(end, tz=timezone.utc),
),
),
failure_message=str(e),
exception=e,
)
return checkpoint
if __name__ == "__main__":
@@ -442,7 +672,7 @@ if __name__ == "__main__":
import time
slack_channel = os.environ.get("SLACK_CHANNEL")
connector = SlackPollConnector(
connector = SlackConnector(
channels=[slack_channel] if slack_channel else None,
)
connector.load_credentials({"slack_bot_token": os.environ["SLACK_BOT_TOKEN"]})
@@ -450,6 +680,17 @@ if __name__ == "__main__":
current = time.time()
one_day_ago = current - 24 * 60 * 60 # 1 day
document_batches = connector.poll_source(one_day_ago, current)
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
print(next(document_batches))
gen = connector.load_from_checkpoint(one_day_ago, current, checkpoint)
try:
for document_or_failure in gen:
if isinstance(document_or_failure, Document):
print(document_or_failure)
elif isinstance(document_or_failure, ConnectorFailure):
print(document_or_failure)
except StopIteration as e:
checkpoint = e.value
print("Next checkpoint:", checkpoint)
print("Next checkpoint:", checkpoint)


@@ -34,9 +34,14 @@ def get_message_link(
) -> str:
channel_id = channel_id or event["channel"]
message_ts = event["ts"]
response = client.chat_getPermalink(channel=channel_id, message_ts=message_ts)
permalink = response["permalink"]
return permalink
message_ts_without_dot = message_ts.replace(".", "")
thread_ts = event.get("thread_ts")
base_url = get_base_url(client.token)
link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + (
f"?thread_ts={thread_ts}" if thread_ts else ""
)
return link
def _make_slack_api_call_paginated(
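
One note on the permalink hunk above: instead of calling chat_getPermalink once
per message, the link is now assembled locally from get_base_url(client.token),
which removes a Slack API round-trip per message. Roughly, with made-up values:

    base_url = "https://example.slack.com/"  # whatever get_base_url returns
    channel_id = "C01AAAAAA"
    message_ts = "1718000123.000200"
    link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts.replace('.', '')}"
    # -> https://example.slack.com/archives/C01AAAAAA/p1718000123000200
    # plus "?thread_ts=<thread_ts>" when the message is part of a thread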