Background update improvements (#3)

* Add full indexing to background loading

* Add oldest/latest support for slack pull
This commit is contained in:
Chris Weaver 2023-04-29 15:15:26 -07:00 committed by GitHub
parent ed8fe75dd3
commit f1936fb755
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 7 deletions

View File

@ -5,6 +5,7 @@ from danswer.connectors.slack.config import get_pull_frequency
from danswer.connectors.slack.pull import SlackPullLoader
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.utils.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logging import setup_logger
logger = setup_logger()
@ -21,6 +22,7 @@ def run_update():
# TODO (chris): implement a more generic way to run updates
# so we don't need to edit this file for future connectors
dynamic_config_store = get_dynamic_config_store()
indexing_pipeline = build_indexing_pipeline()
current_time = int(time.time())
# Slack
@ -40,7 +42,7 @@ def run_update():
):
logger.info(f"Running slack pull from {last_pull or 0} to {current_time}")
for doc_batch in SlackPullLoader().load(last_pull or 0, current_time):
print(len(doc_batch))
indexing_pipeline(doc_batch)
dynamic_config_store.store(last_slack_pull_key, current_time)

View File

@ -93,7 +93,10 @@ def get_channels(client: WebClient) -> list[dict[str, Any]]:
def get_channel_messages(
client: WebClient, channel: dict[str, Any]
client: WebClient,
channel: dict[str, Any],
oldest: str | None = None,
latest: str | None = None,
) -> list[MessageType]:
"""Get all messages in a channel"""
# join so that the bot can access messages
@ -104,7 +107,10 @@ def get_channel_messages(
messages: list[MessageType] = []
for result in _make_slack_api_call(
client.conversations_history, channel=channel["id"]
client.conversations_history,
channel=channel["id"],
oldest=oldest,
latest=latest,
):
messages.extend(result["messages"])
return messages
@ -127,12 +133,16 @@ def _default_msg_filter(message: MessageType) -> bool:
def get_all_threads(
client: WebClient,
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
oldest: str | None = None,
latest: str | None = None,
) -> dict[str, list[ThreadType]]:
"""Get all threads in the workspace"""
channels = get_channels(client)
channel_id_to_messages: dict[str, list[dict[str, Any]]] = {}
for channel in channels:
channel_id_to_messages[channel["id"]] = get_channel_messages(client, channel)
channel_id_to_messages[channel["id"]] = get_channel_messages(
client=client, channel=channel, oldest=oldest, latest=latest
)
channel_to_threads: dict[str, list[ThreadType]] = {}
for channel_id, messages in channel_id_to_messages.items():
@ -167,9 +177,13 @@ def thread_to_doc(channel_id: str, thread: ThreadType) -> Document:
)
def get_all_docs(client: WebClient) -> list[Document]:
def get_all_docs(
client: WebClient,
oldest: str | None = None,
latest: str | None = None,
) -> list[Document]:
"""Get all documents in the workspace"""
channel_id_to_threads = get_all_threads(client)
channel_id_to_threads = get_all_threads(client=client, oldest=oldest, latest=latest)
docs: list[Document] = []
for channel_id, threads in channel_id_to_threads.items():
docs.extend(thread_to_doc(channel_id, thread) for thread in threads)
@ -186,6 +200,6 @@ class SlackPullLoader(PullLoader):
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> Generator[List[Document], None, None]:
# NOTE: start/end are forwarded to Slack as the `oldest`/`latest` bounds; TODO: confirm thread replies are bounded too
all_docs = get_all_docs(self.client)
all_docs = get_all_docs(client=self.client, oldest=str(start), latest=str(end))
for i in range(0, len(all_docs), self.batch_size):
yield all_docs[i : i + self.batch_size]