diff --git a/backend/.gitignore b/backend/.gitignore index b365a5c89..5017ee720 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -8,3 +8,4 @@ api_keys.py .env vespa-app.zip dynamic_config_storage/ +celerybeat-schedule diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 8e6562e04..0151634ed 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -82,6 +82,7 @@ class DocumentSource(str, Enum): GOOGLE_SITES = "google_sites" ZENDESK = "zendesk" LOOPIO = "loopio" + SHAREPOINT = "sharepoint" class DocumentIndexType(str, Enum): diff --git a/backend/danswer/connectors/cross_connector_utils/file_utils.py b/backend/danswer/connectors/cross_connector_utils/file_utils.py index 6587cc4fa..b0a9c723f 100644 --- a/backend/danswer/connectors/cross_connector_utils/file_utils.py +++ b/backend/danswer/connectors/cross_connector_utils/file_utils.py @@ -140,3 +140,19 @@ def read_file( file_content_raw += line return file_content_raw, metadata + + +def is_text_file_extension(file_name: str) -> bool: + extensions = ( + ".txt", + ".mdx", + ".md", + ".conf", + ".log", + ".json", + ".xml", + ".yaml", + ".yml", + ".json", + ) + return any(file_name.endswith(ext) for ext in extensions) diff --git a/backend/danswer/connectors/factory.py b/backend/danswer/connectors/factory.py index f25aab615..0c794fdbe 100644 --- a/backend/danswer/connectors/factory.py +++ b/backend/danswer/connectors/factory.py @@ -31,6 +31,7 @@ from danswer.connectors.slack.connector import SlackPollConnector from danswer.connectors.web.connector import WebConnector from danswer.connectors.zendesk.connector import ZendeskConnector from danswer.connectors.zulip.connector import ZulipConnector +from danswer.connectors.sharepoint.connector import SharepointConnector class ConnectorMissingException(Exception): @@ -68,6 +69,7 @@ def identify_connector_class( DocumentSource.GOOGLE_SITES: GoogleSitesConnector, DocumentSource.ZENDESK: 
ZendeskConnector, DocumentSource.LOOPIO: LoopioConnector, + DocumentSource.SHAREPOINT: SharepointConnector, } connector_by_source = connector_map.get(source, {}) diff --git a/backend/danswer/connectors/sharepoint/__init__.py b/backend/danswer/connectors/sharepoint/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/danswer/connectors/sharepoint/connector.py b/backend/danswer/connectors/sharepoint/connector.py new file mode 100644 index 000000000..f4887bd8c --- /dev/null +++ b/backend/danswer/connectors/sharepoint/connector.py @@ -0,0 +1,266 @@ +import io +import os +import tempfile +from datetime import datetime +from datetime import timezone +from typing import Any + +import docx # type: ignore +import msal # type: ignore +import openpyxl # type: ignore +import pptx # type: ignore +from office365.graph_client import GraphClient # type: ignore +from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore +from office365.onedrive.sites.site import Site # type: ignore + +from danswer.configs.app_configs import INDEX_BATCH_SIZE +from danswer.configs.constants import DocumentSource +from danswer.connectors.cross_connector_utils.file_utils import is_text_file_extension +from danswer.connectors.cross_connector_utils.file_utils import read_pdf_file +from danswer.connectors.interfaces import GenerateDocumentsOutput +from danswer.connectors.interfaces import LoadConnector +from danswer.connectors.interfaces import PollConnector +from danswer.connectors.interfaces import SecondsSinceUnixEpoch +from danswer.connectors.models import BasicExpertInfo +from danswer.connectors.models import ConnectorMissingCredentialError +from danswer.connectors.models import Document +from danswer.connectors.models import Section +from danswer.utils.logger import setup_logger + +UNSUPPORTED_FILE_TYPE_CONTENT = "" # idea copied from the google drive side of things + + +logger = setup_logger() + + +def get_text_from_xlsx_driveitem(driveitem_object: 
DriveItem) -> str: + file_content = driveitem_object.get_content().execute_query().value + excel_file = io.BytesIO(file_content) + workbook = openpyxl.load_workbook(excel_file, read_only=True) + + full_text = [] + for sheet in workbook.worksheets: + sheet_string = "\n".join( + ",".join(map(str, row)) + for row in sheet.iter_rows(min_row=1, values_only=True) + ) + full_text.append(sheet_string) + + return "\n".join(full_text) + + +def get_text_from_docx_driveitem(driveitem_object: DriveItem) -> str: + file_content = driveitem_object.get_content().execute_query().value + full_text = [] + + with tempfile.TemporaryDirectory() as local_path: + with open(os.path.join(local_path, driveitem_object.name), "wb") as local_file: + local_file.write(file_content) + doc = docx.Document(local_file.name) + for para in doc.paragraphs: + full_text.append(para.text) + return "\n".join(full_text) + + +def get_text_from_pdf_driveitem(driveitem_object: DriveItem) -> str: + file_content = driveitem_object.get_content().execute_query().value + file_text = read_pdf_file( + file=io.BytesIO(file_content), file_name=driveitem_object.name + ) + return file_text + + +def get_text_from_txt_driveitem(driveitem_object: DriveItem) -> str: + file_content: bytes = driveitem_object.get_content().execute_query().value + text_string = file_content.decode("utf-8") + return text_string + + +def get_text_from_pptx_driveitem(driveitem_object: DriveItem): + file_content = driveitem_object.get_content().execute_query().value + pptx_stream = io.BytesIO(file_content) + with tempfile.NamedTemporaryFile() as temp: + temp.write(pptx_stream.getvalue()) + presentation = pptx.Presentation(temp.name) + extracted_text = "" + for slide_number, slide in enumerate(presentation.slides, start=1): + extracted_text += f"\nSlide {slide_number}:\n" + + for shape in slide.shapes: + if hasattr(shape, "text"): + extracted_text += shape.text + "\n" + + return extracted_text + + +class SharepointConnector(LoadConnector, 
PollConnector): + def __init__( + self, + batch_size: int = INDEX_BATCH_SIZE, + sites: list[str] = [], + ) -> None: + self.batch_size = batch_size + self.graph_client: GraphClient | None = None + self.requested_site_list: list[str] = sites + + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + aad_client_id = credentials["aad_client_id"] + aad_client_secret = credentials["aad_client_secret"] + aad_directory_id = credentials["aad_directory_id"] + + def _acquire_token_func() -> dict[str, Any]: + """ + Acquire token via MSAL + """ + authority_url = f"https://login.microsoftonline.com/{aad_directory_id}" + app = msal.ConfidentialClientApplication( + authority=authority_url, + client_id=aad_client_id, + client_credential=aad_client_secret, + ) + token = app.acquire_token_for_client( + scopes=["https://graph.microsoft.com/.default"] + ) + return token + + self.graph_client = GraphClient(_acquire_token_func) + return None + + def get_all_driveitem_objects( + self, + site_object_list: list[Site], + start: datetime | None = None, + end: datetime | None = None, + ) -> list[DriveItem]: + filter_str = "" + if start is not None and end is not None: + filter_str = f"last_modified_datetime ge {start.isoformat()} and last_modified_datetime le {end.isoformat()}" + + driveitem_list = [] + for site_object in site_object_list: + site_list_objects = site_object.lists.get().execute_query() + for site_list_object in site_list_objects: + try: + query = site_list_object.drive.root.get_files(True) + if filter_str: + query = query.filter(filter_str) + driveitems = query.execute_query() + driveitem_list.extend(driveitems) + except Exception: + # Sites include things that do not contain .drive.root so this fails + # but this is fine, as there are no actually documents in those + pass + + return driveitem_list + + def get_all_site_objects(self) -> list[Site]: + if self.graph_client is None: + raise ConnectorMissingCredentialError("Sharepoint") + + 
site_object_list: list[Site] = [] + + sites_object = self.graph_client.sites.get().execute_query() + + if len(self.requested_site_list) > 0: + for requested_site in self.requested_site_list: + adjusted_string = "/" + requested_site.replace(" ", "") + for site_object in sites_object: + if site_object.web_url.endswith(adjusted_string): + site_object_list.append(site_object) + else: + site_object_list.extend(sites_object) + + return site_object_list + + def _fetch_from_sharepoint( + self, start: datetime | None = None, end: datetime | None = None + ) -> GenerateDocumentsOutput: + if self.graph_client is None: + raise ConnectorMissingCredentialError("Sharepoint") + + site_object_list = self.get_all_site_objects() + + driveitem_list = self.get_all_driveitem_objects( + site_object_list=site_object_list, + start=start, + end=end, + ) + + # goes over all urls, converts them into Document objects and then yields them in batches + doc_batch: list[Document] = [] + batch_count = 0 + for driveitem_object in driveitem_list: + doc_batch.append( + self.convert_driveitem_object_to_document(driveitem_object) + ) + + batch_count += 1 + if batch_count >= self.batch_size: + yield doc_batch + batch_count = 0 + doc_batch = [] + yield doc_batch + + def convert_driveitem_object_to_document( + self, + driveitem_object: DriveItem, + ) -> Document: + file_text = self.extract_driveitem_text(driveitem_object) + doc = Document( + id=driveitem_object.id, + sections=[Section(link=driveitem_object.web_url, text=file_text)], + source=DocumentSource.SHAREPOINT, + semantic_identifier=driveitem_object.name, + doc_updated_at=driveitem_object.last_modified_datetime.replace( + tzinfo=timezone.utc + ), + primary_owners=[ + BasicExpertInfo( + display_name=driveitem_object.last_modified_by.user.displayName, + email=driveitem_object.last_modified_by.user.email, + ) + ], + metadata={}, + ) + return doc + + def extract_driveitem_text(self, driveitem_object: DriveItem) -> str: + driveitem_name = 
driveitem_object.name + driveitem_text = UNSUPPORTED_FILE_TYPE_CONTENT + + if driveitem_name.endswith(".docx"): + driveitem_text = get_text_from_docx_driveitem(driveitem_object) + elif driveitem_name.endswith(".pdf"): + driveitem_text = get_text_from_pdf_driveitem(driveitem_object) + elif driveitem_name.endswith(".xlsx"): + driveitem_text = get_text_from_xlsx_driveitem(driveitem_object) + elif driveitem_name.endswith(".pptx"): + driveitem_text = get_text_from_pptx_driveitem(driveitem_object) + elif is_text_file_extension(driveitem_name): + driveitem_text = get_text_from_txt_driveitem(driveitem_object) + + return driveitem_text + + def load_from_state(self) -> GenerateDocumentsOutput: + return self._fetch_from_sharepoint() + + def poll_source( + self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch + ) -> GenerateDocumentsOutput: + start_datetime = datetime.utcfromtimestamp(start) + end_datetime = datetime.utcfromtimestamp(end) + return self._fetch_from_sharepoint(start=start_datetime, end=end_datetime) + + +if __name__ == "__main__": + connector = SharepointConnector(sites=os.environ["SITES"].split(",")) + + connector.load_credentials( + { + "aad_client_id": os.environ["AAD_CLIENT_ID"], + "aad_client_secret": os.environ["AAD_CLIENT_SECRET"], + "aad_directory_id": os.environ["AAD_CLIENT_DIRECTORY_ID"], + } + ) + document_batches = connector.load_from_state() + print(next(document_batches)) diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 8f0572f15..b48c10302 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -6,7 +6,6 @@ celery==5.3.4 chardet==5.2.0 dask==2023.8.1 distributed==2023.8.1 -python-dateutil==2.8.2 fastapi==0.103.0 fastapi-users==11.0.0 fastapi-users-db-sqlalchemy==5.0.0 @@ -28,18 +27,22 @@ llama-index==0.9.8 Mako==1.2.4 nltk==3.8.1 docx2txt==0.8 -openai==1.3.5 oauthlib==3.2.2 +openai==1.3.5 +openpyxl==3.1.2 playwright==1.40.0 psutil==5.9.5 psycopg2-binary==2.9.9 
pycryptodome==3.19.1 pydantic==1.10.7 PyGithub==1.58.2 +python-dateutil==2.8.2 python-gitlab==3.9.0 +python-pptx==0.6.23 pypdf==3.17.0 pytest-mock==3.12.0 pytest-playwright==0.3.2 +python-docx==1.1.0 python-dotenv==1.0.0 python-multipart==0.0.6 requests==2.31.0 diff --git a/web/public/Sharepoint.png b/web/public/Sharepoint.png new file mode 100644 index 000000000..1dd15f7ff Binary files /dev/null and b/web/public/Sharepoint.png differ diff --git a/web/src/app/admin/connectors/sharepoint/page.tsx b/web/src/app/admin/connectors/sharepoint/page.tsx new file mode 100644 index 000000000..e7fb1a00b --- /dev/null +++ b/web/src/app/admin/connectors/sharepoint/page.tsx @@ -0,0 +1,268 @@ +"use client"; + +import * as Yup from "yup"; +import { TrashIcon, SharepointIcon } from "@/components/icons/icons"; // Make sure you have a Document360 icon +import { fetcher } from "@/lib/fetcher"; +import useSWR, { useSWRConfig } from "swr"; +import { LoadingAnimation } from "@/components/Loading"; +import { HealthCheckBanner } from "@/components/health/healthcheck"; +import { + SharepointConfig, + SharepointCredentialJson, + ConnectorIndexingStatus, + Credential, +} from "@/lib/types"; // Modify or create these types as required +import { adminDeleteCredential, linkCredential } from "@/lib/credential"; +import { CredentialForm } from "@/components/admin/connectors/CredentialForm"; +import { + TextFormField, + TextArrayFieldBuilder, +} from "@/components/admin/connectors/Field"; +import { ConnectorsTable } from "@/components/admin/connectors/table/ConnectorsTable"; +import { ConnectorForm } from "@/components/admin/connectors/ConnectorForm"; +import { usePublicCredentials } from "@/lib/hooks"; +import { AdminPageTitle } from "@/components/admin/Title"; +import { Card, Text, Title } from "@tremor/react"; + +const MainSection = () => { + const { mutate } = useSWRConfig(); + const { + data: connectorIndexingStatuses, + isLoading: isConnectorIndexingStatusesLoading, + error: 
isConnectorIndexingStatusesError, + } = useSWR[]>( + "/api/manage/admin/connector/indexing-status", + fetcher + ); + + const { + data: credentialsData, + isLoading: isCredentialsLoading, + error: isCredentialsError, + refreshCredentials, + } = usePublicCredentials(); + + if ( + (!connectorIndexingStatuses && isConnectorIndexingStatusesLoading) || + (!credentialsData && isCredentialsLoading) + ) { + return ; + } + + if (isConnectorIndexingStatusesError || !connectorIndexingStatuses) { + return
Failed to load connectors
; + } + + if (isCredentialsError || !credentialsData) { + return
Failed to load credentials
; + } + + const sharepointConnectorIndexingStatuses: ConnectorIndexingStatus< + SharepointConfig, + SharepointCredentialJson + >[] = connectorIndexingStatuses.filter( + (connectorIndexingStatus) => + connectorIndexingStatus.connector.source === "sharepoint" + ); + + const sharepointCredential: Credential | undefined = + credentialsData.find( + (credential) => credential.credential_json?.aad_client_id + ); + + return ( + <> + + Step 1: Provide Sharepoint credentials + + {sharepointCredential ? ( + <> +
+ Existing Azure AD Client ID: + + {sharepointCredential.credential_json.aad_client_id} + + +
+ + ) : ( + <> + + To index Sharepoint, please provide Azure AD client ID, Client + Secret, and Directory ID. + + + + formBody={ + <> + + + + + } + validationSchema={Yup.object().shape({ + aad_client_id: Yup.string().required( + "Please enter your Azure AD Client ID" + ), + aad_directory_id: Yup.string().required( + "Please enter your Azure AD Directory ID" + ), + aad_client_secret: Yup.string().required( + "Please enter your Azure AD Client Secret" + ), + })} + initialValues={{ + aad_client_id: "", + aad_directory_id: "", + aad_client_secret: "", + }} + onSubmit={(isSuccess) => { + if (isSuccess) { + refreshCredentials(); + } + }} + /> + + + )} + + + Step 2: Manage Sharepoint Connector + + + {sharepointConnectorIndexingStatuses.length > 0 && ( + <> + + We index the most recently updated documents from each Sharepoint + instance listed below regularly. + + + The initial poll at this time retrieves documents updated in the past + hour. All subsequent polls execute every ten minutes. This should be + configurable in the future. + +
+ + connectorIndexingStatuses={sharepointConnectorIndexingStatuses} + liveCredential={sharepointCredential} + getCredential={(credential) => + credential.credential_json.aad_directory_id + } + onUpdate={() => + mutate("/api/manage/admin/connector/indexing-status") + } + onCredentialLink={async (connectorId) => { + if (sharepointCredential) { + await linkCredential(connectorId, sharepointCredential.id); + mutate("/api/manage/admin/connector/indexing-status"); + } + }} + specialColumns={[ + { + header: "Sites Group Name", + key: "sites_group_name", + getValue: (ccPairStatus) => { + const connectorConfig = + ccPairStatus.connector.connector_specific_config; + return `${connectorConfig.sites_group_name}`; + }, + }, + { + header: "Connectors", + key: "connectors", + getValue: (ccPairStatus) => { + const connectorConfig = + ccPairStatus.connector.connector_specific_config; + return `${connectorConfig.sites}`; + }, + }, + ]} + /> +
+ + )} + + {sharepointCredential ? ( + + + nameBuilder={(values) => + `Sharepoint-${values.sites_group_name}` + } + ccPairNameBuilder={(values) => + `Sharepoint ${values.sites_group_name}` + } + source="sharepoint" + inputType="poll" + formBody={ + <> + + + } + // formBody={<>} + formBodyBuilder={TextArrayFieldBuilder({ + name: "sites", + label: "Sites:", + subtext: + "Specify 0 or more sites to index. For example, specifying the site " + + "'support' for the 'danswerai' sharepoint will cause us to only index all content " + + "within the 'https://danswerai.sharepoint.com/sites/support' site. " + + "If no sites are specified, all sites in your organization will be indexed.", + })} + validationSchema={Yup.object().shape({ + sites: Yup.array() + .of(Yup.string().required("Site names must be strings")) + .required(), + sites_group_name: Yup.string().required( + "Please enter the name you would like to give this group of sites e.g. engineering " + ), + })} + initialValues={{ + sites: [], + sites_group_name: "", + }} + credentialId={sharepointCredential.id} + refreshFreq={10 * 60} // 10 minutes + /> + + ) : ( + + Please provide all Azure info in Step 1 first! Once you're done with + that, you can then specify which Sharepoint sites you want to make + searchable. + + )} + + ); +}; + +export default function Page() { + return ( +
+
+ +
+ + } title="Sharepoint" /> + + +
+ ); +} diff --git a/web/src/components/icons/icons.tsx b/web/src/components/icons/icons.tsx index cdd6c9388..e5ae456cf 100644 --- a/web/src/components/icons/icons.tsx +++ b/web/src/components/icons/icons.tsx @@ -50,6 +50,7 @@ import hubSpotIcon from "../../../public/HubSpot.png"; import document360Icon from "../../../public/Document360.png"; import googleSitesIcon from "../../../public/GoogleSites.png"; import zendeskIcon from "../../../public/Zendesk.svg"; +import sharepointIcon from "../../../public/Sharepoint.png"; import { FaRobot } from "react-icons/fa"; interface IconProps { @@ -513,6 +514,18 @@ export const RequestTrackerIcon = ({ ); +export const SharepointIcon = ({ + size = 16, + className = defaultTailwindCSS, +}: IconProps) => ( +
+ Logo +
+); + export const GongIcon = ({ size = 16, className = defaultTailwindCSS, diff --git a/web/src/lib/sources.ts b/web/src/lib/sources.ts index 67b524669..d2aaab03f 100644 --- a/web/src/lib/sources.ts +++ b/web/src/lib/sources.ts @@ -18,6 +18,7 @@ import { NotionIcon, ProductboardIcon, RequestTrackerIcon, + SharepointIcon, SlabIcon, SlackIcon, ZendeskIcon, @@ -152,6 +153,11 @@ const SOURCE_METADATA_MAP: SourceMap = { displayName: "Loopio", category: SourceCategory.AppConnection, }, + sharepoint: { + icon: SharepointIcon, + displayName: "Sharepoint", + category: SourceCategory.AppConnection, + }, }; function fillSourceMetadata( diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index 27de27bd8..c6841bb22 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -32,6 +32,7 @@ export type ValidSources = | "file" | "google_sites" | "loopio" + | "sharepoint" | "zendesk"; export type ValidInputTypes = "load_state" | "poll" | "event"; @@ -104,6 +105,11 @@ export interface JiraConfig { jira_project_url: string; } +export interface SharepointConfig { + sites?: string[]; + sites_group_name: string; +} + export interface ProductboardConfig {} export interface SlackConfig { @@ -300,6 +306,12 @@ export interface ZendeskCredentialJson { zendesk_token: string; } +export interface SharepointCredentialJson { + aad_client_id: string; + aad_client_secret: string; + aad_directory_id: string; +} + // DELETION export interface DeletionAttemptSnapshot {