Added sharepoint connector (#963)

This commit is contained in:
Hagen O'Neill 2024-01-25 16:16:10 -05:00 committed by GitHub
parent e94fd8b022
commit d6d83e79f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 590 additions and 2 deletions

1
backend/.gitignore vendored
View File

@ -8,3 +8,4 @@ api_keys.py
.env
vespa-app.zip
dynamic_config_storage/
celerybeat-schedule

View File

@ -82,6 +82,7 @@ class DocumentSource(str, Enum):
GOOGLE_SITES = "google_sites"
ZENDESK = "zendesk"
LOOPIO = "loopio"
SHAREPOINT = "sharepoint"
class DocumentIndexType(str, Enum):

View File

@ -140,3 +140,19 @@ def read_file(
file_content_raw += line
return file_content_raw, metadata
def is_text_file_extension(file_name: str) -> bool:
extensions = (
".txt",
".mdx",
".md",
".conf",
".log",
".json",
".xml",
".yaml",
".yml",
".json",
)
return any(file_name.endswith(ext) for ext in extensions)

View File

@ -31,6 +31,7 @@ from danswer.connectors.slack.connector import SlackPollConnector
from danswer.connectors.web.connector import WebConnector
from danswer.connectors.zendesk.connector import ZendeskConnector
from danswer.connectors.zulip.connector import ZulipConnector
from danswer.connectors.sharepoint.connector import SharepointConnector
class ConnectorMissingException(Exception):
@ -68,6 +69,7 @@ def identify_connector_class(
DocumentSource.GOOGLE_SITES: GoogleSitesConnector,
DocumentSource.ZENDESK: ZendeskConnector,
DocumentSource.LOOPIO: LoopioConnector,
DocumentSource.SHAREPOINT: SharepointConnector,
}
connector_by_source = connector_map.get(source, {})

View File

@ -0,0 +1,266 @@
import io
import os
import tempfile
from datetime import datetime
from datetime import timezone
from typing import Any
import docx # type: ignore
import msal # type: ignore
import openpyxl # type: ignore
import pptx # type: ignore
from office365.graph_client import GraphClient # type: ignore
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore
from office365.onedrive.sites.site import Site # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.file_utils import is_text_file_extension
from danswer.connectors.cross_connector_utils.file_utils import read_pdf_file
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger
UNSUPPORTED_FILE_TYPE_CONTENT = "" # idea copied from the google drive side of things
logger = setup_logger()
def get_text_from_xlsx_driveitem(driveitem_object: DriveItem) -> str:
file_content = driveitem_object.get_content().execute_query().value
excel_file = io.BytesIO(file_content)
workbook = openpyxl.load_workbook(excel_file, read_only=True)
full_text = []
for sheet in workbook.worksheets:
sheet_string = "\n".join(
",".join(map(str, row))
for row in sheet.iter_rows(min_row=1, values_only=True)
)
full_text.append(sheet_string)
return "\n".join(full_text)
def get_text_from_docx_driveitem(driveitem_object: DriveItem) -> str:
file_content = driveitem_object.get_content().execute_query().value
full_text = []
with tempfile.TemporaryDirectory() as local_path:
with open(os.path.join(local_path, driveitem_object.name), "wb") as local_file:
local_file.write(file_content)
doc = docx.Document(local_file.name)
for para in doc.paragraphs:
full_text.append(para.text)
return "\n".join(full_text)
def get_text_from_pdf_driveitem(driveitem_object: DriveItem) -> str:
file_content = driveitem_object.get_content().execute_query().value
file_text = read_pdf_file(
file=io.BytesIO(file_content), file_name=driveitem_object.name
)
return file_text
def get_text_from_txt_driveitem(driveitem_object: DriveItem) -> str:
file_content: bytes = driveitem_object.get_content().execute_query().value
text_string = file_content.decode("utf-8")
return text_string
def get_text_from_pptx_driveitem(driveitem_object: DriveItem):
file_content = driveitem_object.get_content().execute_query().value
pptx_stream = io.BytesIO(file_content)
with tempfile.NamedTemporaryFile() as temp:
temp.write(pptx_stream.getvalue())
presentation = pptx.Presentation(temp.name)
extracted_text = ""
for slide_number, slide in enumerate(presentation.slides, start=1):
extracted_text += f"\nSlide {slide_number}:\n"
for shape in slide.shapes:
if hasattr(shape, "text"):
extracted_text += shape.text + "\n"
return extracted_text
class SharepointConnector(LoadConnector, PollConnector):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
sites: list[str] = [],
) -> None:
self.batch_size = batch_size
self.graph_client: GraphClient | None = None
self.requested_site_list: list[str] = sites
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
aad_client_id = credentials["aad_client_id"]
aad_client_secret = credentials["aad_client_secret"]
aad_directory_id = credentials["aad_directory_id"]
def _acquire_token_func() -> dict[str, Any]:
"""
Acquire token via MSAL
"""
authority_url = f"https://login.microsoftonline.com/{aad_directory_id}"
app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=aad_client_id,
client_credential=aad_client_secret,
)
token = app.acquire_token_for_client(
scopes=["https://graph.microsoft.com/.default"]
)
return token
self.graph_client = GraphClient(_acquire_token_func)
return None
def get_all_driveitem_objects(
self,
site_object_list: list[Site],
start: datetime | None = None,
end: datetime | None = None,
) -> list[DriveItem]:
filter_str = ""
if start is not None and end is not None:
filter_str = f"last_modified_datetime ge {start.isoformat()} and last_modified_datetime le {end.isoformat()}"
driveitem_list = []
for site_object in site_object_list:
site_list_objects = site_object.lists.get().execute_query()
for site_list_object in site_list_objects:
try:
query = site_list_object.drive.root.get_files(True)
if filter_str:
query = query.filter(filter_str)
driveitems = query.execute_query()
driveitem_list.extend(driveitems)
except Exception:
# Sites include things that do not contain .drive.root so this fails
# but this is fine, as there are no actually documents in those
pass
return driveitem_list
def get_all_site_objects(self) -> list[Site]:
if self.graph_client is None:
raise ConnectorMissingCredentialError("Sharepoint")
site_object_list: list[Site] = []
sites_object = self.graph_client.sites.get().execute_query()
if len(self.requested_site_list) > 0:
for requested_site in self.requested_site_list:
adjusted_string = "/" + requested_site.replace(" ", "")
for site_object in sites_object:
if site_object.web_url.endswith(adjusted_string):
site_object_list.append(site_object)
else:
site_object_list.extend(sites_object)
return site_object_list
def _fetch_from_sharepoint(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
if self.graph_client is None:
raise ConnectorMissingCredentialError("Sharepoint")
site_object_list = self.get_all_site_objects()
driveitem_list = self.get_all_driveitem_objects(
site_object_list=site_object_list,
start=start,
end=end,
)
# goes over all urls, converts them into Document objects and then yjelds them in batches
doc_batch: list[Document] = []
batch_count = 0
for driveitem_object in driveitem_list:
doc_batch.append(
self.convert_driveitem_object_to_document(driveitem_object)
)
batch_count += 1
if batch_count >= self.batch_size:
yield doc_batch
batch_count = 0
doc_batch = []
yield doc_batch
def convert_driveitem_object_to_document(
self,
driveitem_object: DriveItem,
) -> Document:
file_text = self.extract_driveitem_text(driveitem_object)
doc = Document(
id=driveitem_object.id,
sections=[Section(link=driveitem_object.web_url, text=file_text)],
source=DocumentSource.SHAREPOINT,
semantic_identifier=driveitem_object.name,
doc_updated_at=driveitem_object.last_modified_datetime.replace(
tzinfo=timezone.utc
),
primary_owners=[
BasicExpertInfo(
display_name=driveitem_object.last_modified_by.user.displayName,
email=driveitem_object.last_modified_by.user.email,
)
],
metadata={},
)
return doc
def extract_driveitem_text(self, driveitem_object: DriveItem) -> str:
driveitem_name = driveitem_object.name
driveitem_text = UNSUPPORTED_FILE_TYPE_CONTENT
if driveitem_name.endswith(".docx"):
driveitem_text = get_text_from_docx_driveitem(driveitem_object)
elif driveitem_name.endswith(".pdf"):
driveitem_text = get_text_from_pdf_driveitem(driveitem_object)
elif driveitem_name.endswith(".xlsx"):
driveitem_text = get_text_from_xlsx_driveitem(driveitem_object)
elif driveitem_name.endswith(".pptx"):
driveitem_text = get_text_from_xlsx_driveitem(driveitem_object)
elif is_text_file_extension(driveitem_name):
driveitem_text = get_text_from_txt_driveitem(driveitem_object)
return driveitem_text
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_sharepoint()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.utcfromtimestamp(start)
end_datetime = datetime.utcfromtimestamp(end)
return self._fetch_from_sharepoint(start=start_datetime, end=end_datetime)
if __name__ == "__main__":
connector = SharepointConnector(sites=os.environ["SITES"].split(","))
connector.load_credentials(
{
"aad_client_id": os.environ["AAD_CLIENT_ID"],
"aad_client_secret": os.environ["AAD_CLIENT_SECRET"],
"aad_directory_id": os.environ["AAD_CLIENT_DIRECTORY_ID"],
}
)
document_batches = connector.load_from_state()
print(next(document_batches))

View File

@ -6,7 +6,6 @@ celery==5.3.4
chardet==5.2.0
dask==2023.8.1
distributed==2023.8.1
python-dateutil==2.8.2
fastapi==0.103.0
fastapi-users==11.0.0
fastapi-users-db-sqlalchemy==5.0.0
@ -28,18 +27,22 @@ llama-index==0.9.8
Mako==1.2.4
nltk==3.8.1
docx2txt==0.8
openai==1.3.5
oauthlib==3.2.2
openai==1.3.5
openpyxl==3.1.2
playwright==1.40.0
psutil==5.9.5
psycopg2-binary==2.9.9
pycryptodome==3.19.1
pydantic==1.10.7
PyGithub==1.58.2
python-dateutil==2.8.2
python-gitlab==3.9.0
python-pptx==0.6.23
pypdf==3.17.0
pytest-mock==3.12.0
pytest-playwright==0.3.2
python-docx==1.1.0
python-dotenv==1.0.0
python-multipart==0.0.6
requests==2.31.0

BIN
web/public/Sharepoint.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 140 KiB

View File

@ -0,0 +1,268 @@
"use client";
import * as Yup from "yup";
import { TrashIcon, SharepointIcon } from "@/components/icons/icons"; // Make sure you have a Document360 icon
import { fetcher } from "@/lib/fetcher";
import useSWR, { useSWRConfig } from "swr";
import { LoadingAnimation } from "@/components/Loading";
import { HealthCheckBanner } from "@/components/health/healthcheck";
import {
SharepointConfig,
SharepointCredentialJson,
ConnectorIndexingStatus,
Credential,
} from "@/lib/types"; // Modify or create these types as required
import { adminDeleteCredential, linkCredential } from "@/lib/credential";
import { CredentialForm } from "@/components/admin/connectors/CredentialForm";
import {
TextFormField,
TextArrayFieldBuilder,
} from "@/components/admin/connectors/Field";
import { ConnectorsTable } from "@/components/admin/connectors/table/ConnectorsTable";
import { ConnectorForm } from "@/components/admin/connectors/ConnectorForm";
import { usePublicCredentials } from "@/lib/hooks";
import { AdminPageTitle } from "@/components/admin/Title";
import { Card, Text, Title } from "@tremor/react";
const MainSection = () => {
const { mutate } = useSWRConfig();
const {
data: connectorIndexingStatuses,
isLoading: isConnectorIndexingStatusesLoading,
error: isConnectorIndexingStatusesError,
} = useSWR<ConnectorIndexingStatus<any, any>[]>(
"/api/manage/admin/connector/indexing-status",
fetcher
);
const {
data: credentialsData,
isLoading: isCredentialsLoading,
error: isCredentialsError,
refreshCredentials,
} = usePublicCredentials();
if (
(!connectorIndexingStatuses && isConnectorIndexingStatusesLoading) ||
(!credentialsData && isCredentialsLoading)
) {
return <LoadingAnimation text="Loading" />;
}
if (isConnectorIndexingStatusesError || !connectorIndexingStatuses) {
return <div>Failed to load connectors</div>;
}
if (isCredentialsError || !credentialsData) {
return <div>Failed to load credentials</div>;
}
const sharepointConnectorIndexingStatuses: ConnectorIndexingStatus<
SharepointConfig,
SharepointCredentialJson
>[] = connectorIndexingStatuses.filter(
(connectorIndexingStatus) =>
connectorIndexingStatus.connector.source === "sharepoint"
);
const sharepointCredential: Credential<SharepointCredentialJson> | undefined =
credentialsData.find(
(credential) => credential.credential_json?.aad_client_id
);
return (
<>
<Title className="mb-2 mt-6 ml-auto mr-auto">
Step 1: Provide Sharepoint credentials
</Title>
{sharepointCredential ? (
<>
<div className="flex mb-1 text-sm">
<Text className="my-auto">Existing Azure AD Client ID: </Text>
<Text className="ml-1 italic my-auto">
{sharepointCredential.credential_json.aad_client_id}
</Text>
<button
className="ml-1 hover:bg-hover rounded p-1"
onClick={async () => {
await adminDeleteCredential(sharepointCredential.id);
refreshCredentials();
}}
>
<TrashIcon />
</button>
</div>
</>
) : (
<>
<Text className="mb-2">
To index Sharepoint, please provide Azure AD client ID, Client
Secret, and Directory ID.
</Text>
<Card className="mt-2">
<CredentialForm<SharepointCredentialJson>
formBody={
<>
<TextFormField
name="aad_client_id"
label="Azure AD Client ID:"
/>
<TextFormField
name="aad_directory_id"
label="Azure AD Directory ID:"
/>
<TextFormField
name="aad_client_secret"
label="Azure AD Client Secret:"
type="password"
/>
</>
}
validationSchema={Yup.object().shape({
aad_client_id: Yup.string().required(
"Please enter your Azure AD Client ID"
),
aad_directory_id: Yup.string().required(
"Please enter your Azure AD Directory ID"
),
aad_client_secret: Yup.string().required(
"Please enter your Azure AD Client Secret"
),
})}
initialValues={{
aad_client_id: "",
aad_directory_id: "",
aad_client_secret: "",
}}
onSubmit={(isSuccess) => {
if (isSuccess) {
refreshCredentials();
}
}}
/>
</Card>
</>
)}
<Title className="mb-2 mt-6 ml-auto mr-auto">
Step 2: Manage Sharepoint Connector
</Title>
{sharepointConnectorIndexingStatuses.length > 0 && (
<>
<Text className="mb-2">
We index the most recently updated tickets from each Sharepoint
instance listed below regularly.
</Text>
<Text className="mb-2">
The initial poll at this time retrieves tickets updated in the past
hour. All subsequent polls execute every ten minutes. This should be
configurable in the future.
</Text>
<div className="mb-2">
<ConnectorsTable<SharepointConfig, SharepointCredentialJson>
connectorIndexingStatuses={sharepointConnectorIndexingStatuses}
liveCredential={sharepointCredential}
getCredential={(credential) =>
credential.credential_json.aad_directory_id
}
onUpdate={() =>
mutate("/api/manage/admin/connector/indexing-status")
}
onCredentialLink={async (connectorId) => {
if (sharepointCredential) {
await linkCredential(connectorId, sharepointCredential.id);
mutate("/api/manage/admin/connector/indexing-status");
}
}}
specialColumns={[
{
header: "Sites Group Name",
key: "sites_group_name",
getValue: (ccPairStatus) => {
const connectorConfig =
ccPairStatus.connector.connector_specific_config;
return `${connectorConfig.sites_group_name}`;
},
},
{
header: "Connectors",
key: "connectors",
getValue: (ccPairStatus) => {
const connectorConfig =
ccPairStatus.connector.connector_specific_config;
return `${connectorConfig.sites}`;
},
},
]}
/>
</div>
</>
)}
{sharepointCredential ? (
<Card className="mt-4">
<ConnectorForm<SharepointConfig>
nameBuilder={(values) =>
`Sharepoint-${values.sites_group_name}`
}
ccPairNameBuilder={(values) =>
`Sharepoint ${values.sites_group_name}`
}
source="sharepoint"
inputType="poll"
formBody={
<>
<TextFormField name="sites_group_name" label="Sites Group Name:" />
</>
}
// formBody={<></>}
formBodyBuilder={TextArrayFieldBuilder({
name: "sites",
label: "Sites:",
subtext:
"Specify 0 or more sites to index. For example, specifying the site " +
"'support' for the 'danswerai' sharepoint will cause us to only index all content " +
"within the 'https://danswerai.sharepoint.com/sites/support' site. " +
"If no sites are specified, all sites in your organization will be indexed.",
})}
validationSchema={Yup.object().shape({
sites: Yup.array()
.of(Yup.string().required("Site names must be strings"))
.required(),
sites_group_name: Yup.string().required(
"Please enter the name you would like to give this group of sites e.g. engineering "
),
})}
initialValues={{
sites: [],
sites_group_name: "",
}}
credentialId={sharepointCredential.id}
refreshFreq={10 * 60} // 10 minutes
/>
</Card>
) : (
<Text>
Please provide all Azure info in Step 1 first! Once you're done with
that, you can then specify which Sharepoint sites you want to make
searchable.
</Text>
)}
</>
);
};
export default function Page() {
return (
<div className="mx-auto container">
<div className="mb-4">
<HealthCheckBanner />
</div>
<AdminPageTitle icon={<SharepointIcon size={32} />} title="Sharepoint" />
<MainSection />
</div>
);
}

View File

@ -50,6 +50,7 @@ import hubSpotIcon from "../../../public/HubSpot.png";
import document360Icon from "../../../public/Document360.png";
import googleSitesIcon from "../../../public/GoogleSites.png";
import zendeskIcon from "../../../public/Zendesk.svg";
import sharepointIcon from "../../../public/Sharepoint.png";
import { FaRobot } from "react-icons/fa";
interface IconProps {
@ -513,6 +514,18 @@ export const RequestTrackerIcon = ({
</div>
);
export const SharepointIcon = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => (
<div
style={{ width: `${size}px`, height: `${size}px` }}
className={`w-[${size}px] h-[${size}px] ` + className}
>
<Image src={sharepointIcon} alt="Logo" width="96" height="96" />
</div>
);
export const GongIcon = ({
size = 16,
className = defaultTailwindCSS,

View File

@ -18,6 +18,7 @@ import {
NotionIcon,
ProductboardIcon,
RequestTrackerIcon,
SharepointIcon,
SlabIcon,
SlackIcon,
ZendeskIcon,
@ -152,6 +153,11 @@ const SOURCE_METADATA_MAP: SourceMap = {
displayName: "Loopio",
category: SourceCategory.AppConnection,
},
sharepoint: {
icon: SharepointIcon,
displayName: "Sharepoint",
category: SourceCategory.AppConnection,
},
};
function fillSourceMetadata(

View File

@ -32,6 +32,7 @@ export type ValidSources =
| "file"
| "google_sites"
| "loopio"
| "sharepoint"
| "zendesk";
export type ValidInputTypes = "load_state" | "poll" | "event";
@ -104,6 +105,11 @@ export interface JiraConfig {
jira_project_url: string;
}
export interface SharepointConfig {
sites?: string[];
sites_group_name: string;
}
export interface ProductboardConfig {}
export interface SlackConfig {
@ -300,6 +306,12 @@ export interface ZendeskCredentialJson {
zendesk_token: string;
}
export interface SharepointCredentialJson {
aad_client_id: string;
aad_client_secret: string;
aad_directory_id: string;
}
// DELETION
export interface DeletionAttemptSnapshot {