Background update improvements (#3)

* Adding full indexing to background loading

* Add oldest/latest support for slack pull
This commit is contained in:
Chris Weaver
2023-04-29 15:15:26 -07:00
committed by GitHub
parent ed8fe75dd3
commit f1936fb755
2 changed files with 23 additions and 7 deletions

View File

@@ -5,6 +5,7 @@ from danswer.connectors.slack.config import get_pull_frequency
from danswer.connectors.slack.pull import SlackPullLoader from danswer.connectors.slack.pull import SlackPullLoader
from danswer.dynamic_configs import get_dynamic_config_store from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.utils.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logging import setup_logger from danswer.utils.logging import setup_logger
logger = setup_logger() logger = setup_logger()
@@ -21,6 +22,7 @@ def run_update():
# TODO (chris): implement a more generic way to run updates # TODO (chris): implement a more generic way to run updates
# so we don't need to edit this file for future connectors # so we don't need to edit this file for future connectors
dynamic_config_store = get_dynamic_config_store() dynamic_config_store = get_dynamic_config_store()
indexing_pipeline = build_indexing_pipeline()
current_time = int(time.time()) current_time = int(time.time())
# Slack # Slack
@@ -40,7 +42,7 @@ def run_update():
): ):
logger.info(f"Running slack pull from {last_pull or 0} to {current_time}") logger.info(f"Running slack pull from {last_pull or 0} to {current_time}")
for doc_batch in SlackPullLoader().load(last_pull or 0, current_time): for doc_batch in SlackPullLoader().load(last_pull or 0, current_time):
print(len(doc_batch)) indexing_pipeline(doc_batch)
dynamic_config_store.store(last_slack_pull_key, current_time) dynamic_config_store.store(last_slack_pull_key, current_time)

View File

@@ -93,7 +93,10 @@ def get_channels(client: WebClient) -> list[dict[str, Any]]:
def get_channel_messages( def get_channel_messages(
client: WebClient, channel: dict[str, Any] client: WebClient,
channel: dict[str, Any],
oldest: str | None = None,
latest: str | None = None,
) -> list[MessageType]: ) -> list[MessageType]:
"""Get all messages in a channel""" """Get all messages in a channel"""
# join so that the bot can access messages # join so that the bot can access messages
@@ -104,7 +107,10 @@ def get_channel_messages(
messages: list[MessageType] = [] messages: list[MessageType] = []
for result in _make_slack_api_call( for result in _make_slack_api_call(
client.conversations_history, channel=channel["id"] client.conversations_history,
channel=channel["id"],
oldest=oldest,
latest=latest,
): ):
messages.extend(result["messages"]) messages.extend(result["messages"])
return messages return messages
@@ -127,12 +133,16 @@ def _default_msg_filter(message: MessageType) -> bool:
def get_all_threads( def get_all_threads(
client: WebClient, client: WebClient,
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter, msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
oldest: str | None = None,
latest: str | None = None,
) -> dict[str, list[ThreadType]]: ) -> dict[str, list[ThreadType]]:
"""Get all threads in the workspace""" """Get all threads in the workspace"""
channels = get_channels(client) channels = get_channels(client)
channel_id_to_messages: dict[str, list[dict[str, Any]]] = {} channel_id_to_messages: dict[str, list[dict[str, Any]]] = {}
for channel in channels: for channel in channels:
channel_id_to_messages[channel["id"]] = get_channel_messages(client, channel) channel_id_to_messages[channel["id"]] = get_channel_messages(
client=client, channel=channel, oldest=oldest, latest=latest
)
channel_to_threads: dict[str, list[ThreadType]] = {} channel_to_threads: dict[str, list[ThreadType]] = {}
for channel_id, messages in channel_id_to_messages.items(): for channel_id, messages in channel_id_to_messages.items():
@@ -167,9 +177,13 @@ def thread_to_doc(channel_id: str, thread: ThreadType) -> Document:
) )
def get_all_docs(client: WebClient) -> list[Document]: def get_all_docs(
client: WebClient,
oldest: str | None = None,
latest: str | None = None,
) -> list[Document]:
"""Get all documents in the workspace""" """Get all documents in the workspace"""
channel_id_to_threads = get_all_threads(client) channel_id_to_threads = get_all_threads(client=client, oldest=oldest, latest=latest)
docs: list[Document] = [] docs: list[Document] = []
for channel_id, threads in channel_id_to_threads.items(): for channel_id, threads in channel_id_to_threads.items():
docs.extend(thread_to_doc(channel_id, thread) for thread in threads) docs.extend(thread_to_doc(channel_id, thread) for thread in threads)
@@ -186,6 +200,6 @@ class SlackPullLoader(PullLoader):
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> Generator[List[Document], None, None]: ) -> Generator[List[Document], None, None]:
# TODO: make this respect start and end # TODO: make this respect start and end
all_docs = get_all_docs(self.client) all_docs = get_all_docs(client=self.client, oldest=str(start), latest=str(end))
for i in range(0, len(all_docs), self.batch_size): for i in range(0, len(all_docs), self.batch_size):
yield all_docs[i : i + self.batch_size] yield all_docs[i : i + self.batch_size]