import asyncio
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterable
from collections.abc import Iterable
from datetime import datetime
from datetime import timezone
from typing import Any

from discord import Client
from discord.channel import TextChannel
from discord.channel import Thread
from discord.enums import MessageType
from discord.flags import Intents
from discord.message import Message as DiscordMessage

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger

logger = setup_logger()


_DISCORD_DOC_ID_PREFIX = "DISCORD_"
_SNIPPET_LENGTH = 30


def _convert_message_to_document(
    message: DiscordMessage,
    sections: list[Section],
) -> Document:
    """
    Convert a discord message to a document.

    Sections are collected before calling this function because it relies on async
    calls to fetch the thread history if there is one.

    Args:
        message: The Discord message to convert.
        sections: Pre-built sections holding the message text and jump link.

    Returns:
        A Document whose id is the message id prefixed with "DISCORD_".
    """

    metadata: dict[str, str | list[str]] = {}
    semantic_substring = ""

    # Only messages from TextChannels will make it here but we have to check for it anyways
    if isinstance(message.channel, TextChannel) and (
        channel_name := message.channel.name
    ):
        metadata["Channel"] = channel_name
        semantic_substring += f" in Channel: #{channel_name}"

    # Single messages dont have a title
    title = ""

    # If there is a thread, add more detail to the metadata, title, and semantic identifier
    if isinstance(message.channel, Thread):
        # Threads do have a title
        title = message.channel.name
        metadata["Thread"] = title
        semantic_substring += f" in Thread: {title}"

    snippet: str = (
        message.content[:_SNIPPET_LENGTH].rstrip() + "..."
        if len(message.content) > _SNIPPET_LENGTH
        else message.content
    )

    semantic_identifier = f"{message.author.name} said{semantic_substring}: {snippet}"

    return Document(
        id=f"{_DISCORD_DOC_ID_PREFIX}{message.id}",
        source=DocumentSource.DISCORD,
        semantic_identifier=semantic_identifier,
        # edited_at is None for messages that were never edited; fall back to the
        # creation time so every document carries a timestamp.
        doc_updated_at=message.edited_at or message.created_at,
        title=title,
        sections=sections,
        metadata=metadata,
    )


async def _fetch_filtered_channels(
    discord_client: Client,
    server_ids: list[int] | None,
    channel_names: list[str] | None,
) -> list[TextChannel]:
    """
    Collect the text channels the connector should index.

    A channel is kept when it is a TextChannel, the bot may read its message
    history, and it passes the optional server-id and channel-name filters.
    An empty or None filter means "no restriction".
    """
    # Sets give O(1) membership checks; None means the filter is disabled.
    allowed_server_ids = set(server_ids) if server_ids else None
    allowed_channel_names = set(channel_names) if channel_names else None

    filtered_channels: list[TextChannel] = []

    for channel in discord_client.get_all_channels():
        if not isinstance(channel, TextChannel):
            continue
        if not channel.permissions_for(channel.guild.me).read_message_history:
            continue
        if (
            allowed_server_ids is not None
            and channel.guild.id not in allowed_server_ids
        ):
            continue
        if (
            allowed_channel_names is not None
            and channel.name not in allowed_channel_names
        ):
            continue
        filtered_channels.append(channel)

    logger.info(f"Found {len(filtered_channels)} channels for the authenticated user")
    return filtered_channels


async def _fetch_documents_from_history(
    messageable: TextChannel | Thread,
    start_time: datetime | None,
    end_time: datetime | None,
) -> AsyncIterable[Document]:
    """Yield one Document per default-type message in a channel or thread history."""
    # limit=None is required: discord.py's history() defaults to limit=100, which
    # would silently cap the number of indexed messages per channel/thread.
    async for message in messageable.history(
        limit=None,
        after=start_time,
        before=end_time,
    ):
        # Skip messages that are not the default type
        if message.type != MessageType.default:
            continue

        sections: list[Section] = [
            Section(
                text=message.content,
                link=message.jump_url,
            )
        ]

        yield _convert_message_to_document(message, sections)


async def _fetch_documents_from_channel(
    channel: TextChannel,
    start_time: datetime | None,
    end_time: datetime | None,
) -> AsyncIterable[Document]:
    """
    Yield documents for a channel, then its active threads, then its archived threads.

    Args:
        channel: The text channel to read.
        start_time: Only messages after this time are fetched; clamped up to
            2015-01-01 UTC when earlier.
        end_time: Only messages before this time are fetched.
    """
    # Discord's epoch starts at 2015-01-01
    discord_epoch = datetime(2015, 1, 1, tzinfo=timezone.utc)
    if start_time and start_time < discord_epoch:
        start_time = discord_epoch

    async for doc in _fetch_documents_from_history(channel, start_time, end_time):
        yield doc

    for active_thread in channel.threads:
        async for doc in _fetch_documents_from_history(
            active_thread, start_time, end_time
        ):
            yield doc

    # limit=None for the same reason as history(): the default only returns the
    # first 100 archived threads.
    async for archived_thread in channel.archived_threads(limit=None):
        async for doc in _fetch_documents_from_history(
            archived_thread, start_time, end_time
        ):
            yield doc


def _shutdown_event_loop(loop: asyncio.AbstractEventLoop) -> None:
    """Cancel tasks still pending on `loop`, finalize async generators and close it."""
    try:
        pending = asyncio.all_tasks(loop)
        for task in pending:
            task.cancel()
        if pending:
            # return_exceptions=True so the expected CancelledErrors are collected
            # instead of aborting the shutdown.
            loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
        loop.run_until_complete(loop.shutdown_asyncgens())
    finally:
        loop.close()


def _manage_async_retrieval(
    token: str,
    requested_start_date_string: str,
    channel_names: list[str],
    server_ids: list[int],
    start: datetime | None = None,
    end: datetime | None = None,
) -> Iterable[Document]:
    """
    Bridge the async Discord client into a synchronous document iterator.

    Args:
        token: Discord bot token used to log in.
        requested_start_date_string: Optional "YYYY-MM-DD" lower bound; an empty
            string disables it.
        channel_names: Channel names to include; empty means all.
        server_ids: Server (guild) ids to include; empty means all.
        start: Optional lower bound supplied by the poll window.
        end: Optional upper bound supplied by the poll window.

    Returns:
        A lazy iterable of Documents; the private event loop is created when
        iteration starts and torn down when it ends or is abandoned.

    Raises:
        ValueError: If requested_start_date_string is not in "YYYY-MM-DD" format.
    """
    # parse requested_start_date_string to datetime
    pull_date: datetime | None = (
        datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(
            tzinfo=timezone.utc
        )
        if requested_start_date_string
        else None
    )

    # Set start_time to the later of start and pull_date, or whichever is provided
    lower_bounds = [bound for bound in (start, pull_date) if bound is not None]
    start_time = max(lower_bounds) if lower_bounds else None

    end_time: datetime | None = end

    async def _async_fetch() -> AsyncGenerator[Document, None]:
        intents = Intents.default()
        intents.message_content = True
        async with Client(intents=intents) as discord_client:
            # Keep references to both tasks: the event loop only holds weak
            # references, and we need start_task to observe startup failures.
            start_task = asyncio.create_task(discord_client.start(token))
            ready_task = asyncio.create_task(discord_client.wait_until_ready())
            done, _ = await asyncio.wait(
                {start_task, ready_task},
                return_when=asyncio.FIRST_COMPLETED,
            )
            if ready_task not in done:
                # start() finished first, so login/connect failed (bad token,
                # missing privileged intent, ...). Surface the error instead of
                # waiting forever for a ready event that will never arrive.
                ready_task.cancel()
                start_task.result()
                raise RuntimeError("Discord client stopped before becoming ready")
            ready_task.result()

            filtered_channels: list[TextChannel] = await _fetch_filtered_channels(
                discord_client=discord_client,
                server_ids=server_ids,
                channel_names=channel_names,
            )

            for channel in filtered_channels:
                async for doc in _fetch_documents_from_channel(
                    channel=channel,
                    start_time=start_time,
                    end_time=end_time,
                ):
                    yield doc

    def run_and_yield() -> Iterable[Document]:
        loop = asyncio.new_event_loop()
        async_gen = _async_fetch()
        try:
            while True:
                try:
                    # Drive the async generator one document at a time
                    doc = loop.run_until_complete(anext(async_gen))
                except StopAsyncIteration:
                    break
                yield doc
        finally:
            try:
                # Close the generator explicitly so the Discord client is shut
                # down even when the consumer stops iterating early.
                loop.run_until_complete(async_gen.aclose())
            finally:
                _shutdown_event_loop(loop)

    return run_and_yield()


class DiscordConnector(PollConnector, LoadConnector):
    """
    Connector that indexes Discord messages, optionally restricted by server id,
    channel name and a start date.
    """

    def __init__(
        self,
        server_ids: list[str] | None = None,
        channel_names: list[str] | None = None,
        start_date: str | None = None,  # YYYY-MM-DD
        batch_size: int = INDEX_BATCH_SIZE,
    ):
        """
        Args:
            server_ids: Server ids as strings; converted to ints. None/empty means all.
            channel_names: Channel names to include. None/empty means all.
            start_date: Optional "YYYY-MM-DD" lower bound for messages.
            batch_size: Number of documents per yielded batch.

        Raises:
            ValueError: If a server id is not an integer string.
        """
        self.batch_size = batch_size
        # Copy so the connector never aliases a caller-owned list
        self.channel_names: list[str] = list(channel_names) if channel_names else []
        self.server_ids: list[int] = (
            [int(server_id) for server_id in server_ids] if server_ids else []
        )
        self._discord_bot_token: str | None = None
        self.requested_start_date_string: str = start_date or ""

    @property
    def discord_bot_token(self) -> str:
        """The bot token; raises ConnectorMissingCredentialError if not loaded."""
        if self._discord_bot_token is None:
            raise ConnectorMissingCredentialError("Discord")
        return self._discord_bot_token

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Store the bot token from credentials["discord_bot_token"]."""
        self._discord_bot_token = credentials["discord_bot_token"]
        return None

    def _manage_doc_batching(
        self,
        start: datetime | None = None,
        end: datetime | None = None,
    ) -> GenerateDocumentsOutput:
        """Group retrieved documents into lists of at most batch_size items."""
        doc_batch = []
        for doc in _manage_async_retrieval(
            token=self.discord_bot_token,
            requested_start_date_string=self.requested_start_date_string,
            channel_names=self.channel_names,
            server_ids=self.server_ids,
            start=start,
            end=end,
        ):
            doc_batch.append(doc)
            if len(doc_batch) >= self.batch_size:
                yield doc_batch
                doc_batch = []

        # Flush the final partial batch
        if doc_batch:
            yield doc_batch

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        """Yield batches of documents for messages inside the [start, end] window."""
        return self._manage_doc_batching(
            datetime.fromtimestamp(start, tz=timezone.utc),
            datetime.fromtimestamp(end, tz=timezone.utc),
        )

    def load_from_state(self) -> GenerateDocumentsOutput:
        """Yield batches of documents without any poll-window bounds."""
        return self._manage_doc_batching(None, None)


if __name__ == "__main__":
    import os
    import time

    end = time.time()
    # 1 day
    start = end - 24 * 60 * 60 * 1
    # "1,2,3"
    server_ids: str | None = os.environ.get("server_ids", None)
    # "channel1,channel2"
    channel_names: str | None = os.environ.get("channel_names", None)

    connector = DiscordConnector(
        server_ids=server_ids.split(",") if server_ids else [],
        channel_names=channel_names.split(",") if channel_names else [],
        start_date=os.environ.get("start_date", None),
    )
    connector.load_credentials(
        {"discord_bot_token": os.environ.get("discord_bot_token")}
    )

    for doc_batch in connector.poll_source(start, end):
        for doc in doc_batch:
            print(doc)
import os
import time

import pytest

from onyx.connectors.discord.connector import DiscordConnector
from onyx.connectors.models import Document


def _env_csv(name: str) -> list[str]:
    """Return the comma-separated environment variable `name` as a list ([] if unset/empty)."""
    raw_value: str | None = os.environ.get(name, None)
    if not raw_value:
        return []
    return raw_value.split(",")


@pytest.fixture
def discord_connector() -> DiscordConnector:
    """Build a DiscordConnector configured entirely from environment variables."""
    connector = DiscordConnector(
        server_ids=_env_csv("server_ids"),
        channel_names=_env_csv("channel_names"),
        start_date=os.environ.get("start_date", None),
    )
    credentials = {"discord_bot_token": os.environ.get("DISCORD_BOT_TOKEN")}
    connector.load_credentials(credentials)
    return connector


@pytest.mark.skip(reason="Test Discord is not setup yet!")
def test_discord_poll_connector(discord_connector: DiscordConnector) -> None:
    """Poll the last 15 days and check the document, channel and thread counts."""
    window_end = time.time()
    window_start = window_end - 24 * 60 * 60 * 15  # 15 days

    collected_docs: list[Document] = []
    seen_channels: set[str] = set()
    seen_threads: set[str] = set()

    for batch in discord_connector.poll_source(window_start, window_end):
        for document in batch:
            doc_metadata = document.metadata
            if "Channel" in doc_metadata:
                assert isinstance(doc_metadata["Channel"], str)
                seen_channels.add(doc_metadata["Channel"])
            if "Thread" in doc_metadata:
                assert isinstance(doc_metadata["Thread"], str)
                seen_threads.add(doc_metadata["Thread"])
            collected_docs.append(document)

    # might change based on the channels and servers being used
    assert len(collected_docs) == 10
    assert len(seen_channels) == 2
    assert len(seen_threads) == 2
+ Logo +
+ ); +}; + export const LiteLLMIcon = ({ size = 16, className = defaultTailwindCSS, diff --git a/web/src/lib/connectors/connectors.tsx b/web/src/lib/connectors/connectors.tsx index d02504566a..84b8a71506 100644 --- a/web/src/lib/connectors/connectors.tsx +++ b/web/src/lib/connectors/connectors.tsx @@ -1031,6 +1031,36 @@ For example, specifying .*-support.* as a "channel" will cause the connector to ], advanced_values: [], }, + discord: { + description: "Configure Discord connector", + values: [], + advanced_values: [ + { + type: "list", + query: "Enter Server IDs to include:", + label: "Server IDs", + name: "server_ids", + description: `Specify 0 or more server ids to include. Only channels inside them will be used for indexing`, + optional: true, + }, + { + type: "list", + query: "Enter channel names to include:", + label: "Channels", + name: "channel_names", + description: `Specify 0 or more channels to index. For example, specifying the channel "support" will cause us to only index all content within the "#support" channel. If no channels are specified, all channels the bot has access to will be indexed.`, + optional: true, + }, + { + type: "text", + query: "Enter the Start Date:", + label: "Start Date", + name: "start_date", + description: `Only messages after this date will be indexed. 
Format: YYYY-MM-DD`, + optional: true, + }, + ], + }, freshdesk: { description: "Configure Freshdesk connector", values: [], diff --git a/web/src/lib/connectors/credentials.ts b/web/src/lib/connectors/credentials.ts index b1d1a18d89..bd9c5dfc85 100644 --- a/web/src/lib/connectors/credentials.ts +++ b/web/src/lib/connectors/credentials.ts @@ -195,6 +195,10 @@ export interface AxeroCredentialJson { axero_api_token: string; } +export interface DiscordCredentialJson { + discord_bot_token: string; +} + export interface FreshdeskCredentialJson { freshdesk_domain: string; freshdesk_password: string; @@ -335,6 +339,7 @@ export const credentialTemplates: Record = { web: null, not_applicable: null, ingestion_api: null, + discord: { discord_bot_token: "" } as DiscordCredentialJson, // NOTE: These are Special Cases google_drive: { google_tokens: "" } as GoogleDriveCredentialJson, @@ -368,6 +373,9 @@ export const credentialDisplayNames: Record = { // Slack slack_bot_token: "Slack Bot Token", + // Discord + discord_bot_token: "Discord Bot Token", + // Gmail and Google Drive google_tokens: "Google Oauth Tokens", google_service_account_key: "Google Service Account Key", diff --git a/web/src/lib/sources.ts b/web/src/lib/sources.ts index 664ffc839a..a9a323c09f 100644 --- a/web/src/lib/sources.ts +++ b/web/src/lib/sources.ts @@ -36,6 +36,7 @@ import { GoogleStorageIcon, ColorSlackIcon, XenforoIcon, + ColorDiscordIcon, FreshdeskIcon, FirefliesIcon, EgnyteIcon, @@ -80,6 +81,12 @@ export const SOURCE_METADATA_MAP: SourceMap = { docs: "https://docs.onyx.app/connectors/slack", oauthSupported: true, }, + discord: { + icon: ColorDiscordIcon, + displayName: "Discord", + category: SourceCategory.Messaging, + docs: "https://docs.onyx.app/connectors/discord", + }, gmail: { icon: GmailIcon, displayName: "Gmail", diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index a50e6f8f1d..cab013985d 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -314,6 +314,7 @@ export enum 
ValidSources { GoogleSites = "google_sites", Loopio = "loopio", Dropbox = "dropbox", + Discord = "discord", Salesforce = "salesforce", Sharepoint = "sharepoint", Teams = "teams",